diff --git a/BUILD.bazel b/BUILD.bazel index 7f6ec2d1db7..b8525a1e36c 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -207,6 +207,7 @@ MICROKERNEL_DEFS = [ "src/x16-transposec/x16-transposec.h", "src/x24-transposec/x24-transposec.h", "src/x32-packb/x32-packb.h", + "src/x32-pack-lh/x32-pack-lh.h", "src/x32-packw/x32-packw.h", "src/x32-packx/x32-packx.h", "src/x32-transposec/x32-transposec.h", @@ -231,6 +232,7 @@ MICROKERNEL_HDRS = [ "src/xnnpack/packw.h", "src/xnnpack/packx.h", "src/xnnpack/pad.h", + "src/xnnpack/pack-lh.h", "src/xnnpack/pavgpool.h", "src/xnnpack/ppmm.h", "src/xnnpack/quantization.h", @@ -805,8 +807,9 @@ xnnpack_cxx_library( "@KleidiAI//kai/ukernels/matmul", "@KleidiAI//kai/ukernels/matmul:rhs_pack_kxn_qsi4cxp_qs4cxs1s0", "@KleidiAI//kai/ukernels/matmul:rhs_pack_nxk_qsi4cxp_qs4cxs1s0", - "@KleidiAI//kai/ukernels/matmul:rhs_pack_nxk_qsi4c32p_qsu4c32s1s0", + "@KleidiAI//kai/ukernels/matmul:rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme", "@KleidiAI//kai/ukernels/matmul:rhs_pack_kxn_qsi4c32p_qsu4c32s1s0", + "@KleidiAI//kai/ukernels/matmul:rhs_pack_nxk_qsi4c32p_qsu4c32s1s0", ]), ) diff --git a/CMakeLists.txt b/CMakeLists.txt index 00147e0d6de..50516a85fb9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -429,6 +429,7 @@ SET(OPERATOR_SRCS src/operators/global-average-pooling-nwc.c src/operators/lut-elementwise-nc.c src/operators/max-pooling-nhwc.c + src/operators/pack-lh.c src/operators/reduce-nd.c src/operators/resize-bilinear-nchw.c src/operators/resize-bilinear-nhwc.c @@ -474,6 +475,7 @@ SET(SUBGRAPH_SRCS src/subgraph/log.c src/subgraph/max-pooling-2d.c src/subgraph/negate.c + src/subgraph/pack-lh.c src/subgraph/reciprocal-square-root.c src/subgraph/reshape-helpers.c src/subgraph/scaled-dot-product-attention.c @@ -517,6 +519,7 @@ SET(XNNPACK_SRCS src/configs/lut32norm-config.c src/configs/maxpool-config.c src/configs/pavgpool-config.c + src/configs/pack-lh-config.c src/configs/raddstoreexpminusmax-config.c src/configs/reduce-config.c src/configs/rmax-config.c diff --git a/build_params.bzl b/build_params.bzl index d3deb037f5e..f50b1a6de8c 100644 --- a/build_params.bzl +++ b/build_params.bzl @@ -490,6 +490,10 @@ XNNPACK_PARAMS_FOR_ARCH = { "neonsme2": _create_params( cond = "//:arm_sme2_enabled", copts = ["-march=armv8.2-a+sve+sve2"], + extra_deps = xnnpack_if_kleidiai_enabled([ + "@KleidiAI//kai/ukernels/matmul:lhs_pack_f32p2vlx1_f32_sme", + "@KleidiAI//kai/ukernels/matmul:clamp_f32_f32p_f32p", + ]), ), "aarch32": _create_params( cond = "//build_config:aarch32", diff --git a/build_srcs.bzl b/build_srcs.bzl index 8ee26ea8578..c2e1fa7155c 100644 --- a/build_srcs.bzl +++ b/build_srcs.bzl @@ -24,6 +24,7 @@ OPERATOR_SRCS = [ "src/operators/global-average-pooling-nwc.c", "src/operators/lut-elementwise-nc.c", "src/operators/max-pooling-nhwc.c", + "src/operators/pack-lh.c", "src/operators/reduce-nd.c", "src/operators/resize-bilinear-nchw.c", "src/operators/resize-bilinear-nhwc.c", @@ -70,6 +71,7 @@ SUBGRAPH_SRCS = [ "src/subgraph/log.c", "src/subgraph/max-pooling-2d.c", "src/subgraph/negate.c", + "src/subgraph/pack-lh.c", "src/subgraph/reciprocal-square-root.c", "src/subgraph/reshape-helpers.c", "src/subgraph/rope.c", @@ -118,6 +120,7 @@ XNNPACK_SRCS = [ "src/configs/lut32norm-config.c", "src/configs/maxpool-config.c", "src/configs/pavgpool-config.c", + "src/configs/pack-lh-config.c", "src/configs/raddstoreexpminusmax-config.c", "src/configs/reduce-config.c", "src/configs/rmax-config.c", diff --git a/cmake/gen/neonsme2_microkernels.cmake b/cmake/gen/neonsme2_microkernels.cmake index 
96949217900..53d3e965f18 100644 --- a/cmake/gen/neonsme2_microkernels.cmake +++ b/cmake/gen/neonsme2_microkernels.cmake @@ -9,7 +9,9 @@ # Generator: tools/update-microkernels.py -SET(PROD_NEONSME2_MICROKERNEL_SRCS) +SET(PROD_NEONSME2_MICROKERNEL_SRCS + src/pf32-gemm/pf32-gemm-32x32-minmax-neonsme2.c + src/x32-pack-lh/x32-packlh-neonsme2.c) SET(NON_PROD_NEONSME2_MICROKERNEL_SRCS) diff --git a/cmake/gen/sme_aarch64_microkernels.cmake b/cmake/gen/sme_aarch64_microkernels.cmake new file mode 100644 index 00000000000..7c508c29bd9 --- /dev/null +++ b/cmake/gen/sme_aarch64_microkernels.cmake @@ -0,0 +1,17 @@ +# Copyright 2022 Google LLC +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Description: microkernel filename lists for sme_aarch64 +# +# Auto-generated file. Do not edit! +# Generator: tools/update-microkernels.py + + +SET(PROD_SME_AARCH64_MICROKERNEL_SRCS) + +SET(NON_PROD_SME_AARCH64_MICROKERNEL_SRCS + src/x32-packx-w/x32-packx-w-aarch64-sme-u2.c) + +SET(ALL_SME_AARCH64_MICROKERNEL_SRCS ${PROD_SME_AARCH64_MICROKERNEL_SRCS} + ${NON_PROD_SME_AARCH64_MICROKERNEL_SRCS}) diff --git a/cmake/gen/sme_microkernels.cmake b/cmake/gen/sme_microkernels.cmake new file mode 100644 index 00000000000..d21bab6e9e8 --- /dev/null +++ b/cmake/gen/sme_microkernels.cmake @@ -0,0 +1,16 @@ +# Copyright 2022 Google LLC +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Description: microkernel filename lists for sme +# +# Auto-generated file. Do not edit! +# Generator: tools/update-microkernels.py + + +SET(PROD_SME_MICROKERNEL_SRCS) + +SET(NON_PROD_SME_MICROKERNEL_SRCS) + +SET(ALL_SME_MICROKERNEL_SRCS ${PROD_SME_MICROKERNEL_SRCS} + ${NON_PROD_SME_MICROKERNEL_SRCS}) diff --git a/gen/neonsme2_microkernels.bzl b/gen/neonsme2_microkernels.bzl index 0221ffe5e82..2bb071f179d 100644 --- a/gen/neonsme2_microkernels.bzl +++ b/gen/neonsme2_microkernels.bzl @@ -6,6 +6,8 @@ Auto-generated file. Do not edit! """ PROD_NEONSME2_MICROKERNEL_SRCS = [ + "src/pf32-gemm/pf32-gemm-32x32-minmax-neonsme2.c", + "src/x32-pack-lh/x32-packlh-neonsme2.c", ] NON_PROD_NEONSME2_MICROKERNEL_SRCS = [ diff --git a/gen/sme_aarch64_microkernels.bzl b/gen/sme_aarch64_microkernels.bzl new file mode 100644 index 00000000000..62c316adc5d --- /dev/null +++ b/gen/sme_aarch64_microkernels.bzl @@ -0,0 +1,15 @@ +""" +Microkernel filenames lists for sme_aarch64. + +Auto-generated file. Do not edit! + Generator: tools/update-microkernels.py +""" + +PROD_SME_AARCH64_MICROKERNEL_SRCS = [ +] + +NON_PROD_SME_AARCH64_MICROKERNEL_SRCS = [ + "src/x32-packx-w/x32-packx-w-aarch64-sme-u2.c", +] + +ALL_SME_AARCH64_MICROKERNEL_SRCS = PROD_SME_AARCH64_MICROKERNEL_SRCS + NON_PROD_SME_AARCH64_MICROKERNEL_SRCS diff --git a/gen/sme_microkernels.bzl b/gen/sme_microkernels.bzl new file mode 100644 index 00000000000..057966539e7 --- /dev/null +++ b/gen/sme_microkernels.bzl @@ -0,0 +1,14 @@ +""" +Microkernel filenames lists for sme. + +Auto-generated file. Do not edit! 
+ Generator: tools/update-microkernels.py +""" + +PROD_SME_MICROKERNEL_SRCS = [ +] + +NON_PROD_SME_MICROKERNEL_SRCS = [ +] + +ALL_SME_MICROKERNEL_SRCS = PROD_SME_MICROKERNEL_SRCS + NON_PROD_SME_MICROKERNEL_SRCS diff --git a/include/xnnpack.h b/include/xnnpack.h index 801cbcea64c..a08ac9b09cd 100644 --- a/include/xnnpack.h +++ b/include/xnnpack.h @@ -281,6 +281,8 @@ enum xnn_datatype { /// Quantized 4-bit signed integer with shared per-channel-block quantization /// parameters. xnn_datatype_qbint4 = 12, + /// IEEE754 single-precision packed floating-point. + xnn_datatype_pfp32 = 13, }; /// Define a tensor-type Value and add it to a Subgraph. diff --git a/src/configs/gemm-config.c b/src/configs/gemm-config.c index 4ccae0f55f8..1d30c86614c 100644 --- a/src/configs/gemm-config.c +++ b/src/configs/gemm-config.c @@ -33,6 +33,7 @@ static struct xnn_gemm_config f32_gemm_config = {0}; static struct xnn_gemm_config f32_gemm_nr2_config = {0}; static struct xnn_gemm_config f32_qc4w_gemm_config = {0}; static struct xnn_gemm_config f32_qc8w_gemm_config = {0}; +static struct xnn_gemm_config pf32_gemm_config = {0}; static struct xnn_gemm_config qd8_f16_qb4w_gemm_config = {0}; static struct xnn_gemm_config qd8_f16_qc4w_gemm_config = {0}; static struct xnn_gemm_config qd8_f16_qc8w_gemm_config = {0}; @@ -49,6 +50,7 @@ XNN_INIT_ONCE_GUARD(f32_gemm); XNN_INIT_ONCE_GUARD(f32_gemm_nr2); XNN_INIT_ONCE_GUARD(f32_qc4w_gemm); XNN_INIT_ONCE_GUARD(f32_qc8w_gemm); +XNN_INIT_ONCE_GUARD(pf32_gemm); XNN_INIT_ONCE_GUARD(qd8_f16_qb4w_gemm); XNN_INIT_ONCE_GUARD(qd8_f16_qc4w_gemm); XNN_INIT_ONCE_GUARD(qd8_f16_qc8w_gemm); @@ -241,6 +243,27 @@ static void init_f16_gemm_config(void) { const int kCoreCountThresholdForAdaptiveAvxOptimization = 4; #endif +static void init_pf32_gemm_config(void) { +#if XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI + const struct xnn_hardware_config* hardware_config = + xnn_init_hardware_config(); + assert(hardware_config != NULL); + if (XNN_ENABLE_ARM_SME2 && hardware_config->use_arm_sme2) { + #if XNN_ENABLE_ARM_SME2 + const size_t mr = xnn_pf32_gemm_minmax_ukernel_32x32__neonsme2_get_mr(); + const size_t nr = xnn_pf32_gemm_minmax_ukernel_32x32__neonsme2_get_nr(); + pf32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(nr)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_pf32_gemm_minmax_ukernel_32x32__neonsme2); + pf32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; + pf32_gemm_config.pack_weights_and_biases = xnn_pack_kai_f32_weights_and_biases; + pf32_gemm_config.packed_stride_weights_and_biases = xnn_packed_stride_kai_f32_weights_and_biases; + pf32_gemm_config.mr = mr; + pf32_gemm_config.mr_packed = mr; + pf32_gemm_config.nr = nr; + #endif // XNN_ENABLE_ARM_SME2 + } +#endif // XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI +} + static void init_f32_gemm_config(void) { #if XNN_ARCH_ARM const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); @@ -399,232 +422,232 @@ static void init_f32_gemm_config(void) { } #elif XNN_ARCH_ARM64 #if XNN_ENABLE_ASSEMBLY && !XNN_PLATFORM_IOS && !XNN_PLATFORM_MAC - switch (cpuinfo_get_core(0)->uarch) { - case cpuinfo_uarch_cortex_a72: - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a75_prfm); - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a75_prfm); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = 
xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a75_prfm); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a75_prfm); - f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; - f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; - f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__neon_ld4lane_u4_prfm; - f32_gemm_config.mr = 4; - f32_gemm_config.nr = 8; - break; - case cpuinfo_uarch_cortex_a57: - case cpuinfo_uarch_cortex_a75: - case cpuinfo_uarch_cortex_a76: - case cpuinfo_uarch_exynos_m3: - case cpuinfo_uarch_exynos_m4: - case cpuinfo_uarch_neoverse_n1: - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a75_prfm); - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a75_prfm); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a75_prfm); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a75_prfm); - #if XNN_ENABLE_GEMM_M_SPECIALIZATION + switch (cpuinfo_get_core(0)->uarch) { + case cpuinfo_uarch_cortex_a72: + f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a75_prfm); f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a75_prfm); + f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a75_prfm); f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a75_prfm); - #endif - f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; - f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; - f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__neon_ld4lane_u4_prfm; - f32_gemm_config.mr = 6; - f32_gemm_config.nr = 8; - break; - case cpuinfo_uarch_exynos_m1: - case cpuinfo_uarch_exynos_m2: - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8s4__neonfma); - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_6x8s4__neonfma); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8s4__neonfma); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_6x8s4__neonfma); - #if XNN_ENABLE_GEMM_M_SPECIALIZATION - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8s4__neonfma); - 
f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8s4__neonfma); - #endif - f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; - f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; - f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8s4__neon_ld4lane_u4_prfm; - f32_gemm_config.mr = 6; - f32_gemm_config.nr = 8; - f32_gemm_config.log2_sr = 2; - break; - case cpuinfo_uarch_cortex_a53: - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53_prfm); - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a53_prfm); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53_prfm); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a53_prfm); - #if XNN_ENABLE_GEMM_M_SPECIALIZATION - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a53_prfm); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a53_prfm); - #endif - f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; - f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; - f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__neon_ld4lane_u4_prfm; - f32_gemm_config.mr = 6; - f32_gemm_config.nr = 8; - break; - case cpuinfo_uarch_cortex_a55r0: - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53); - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a53); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a53); - #if XNN_ENABLE_GEMM_M_SPECIALIZATION - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a53); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a53); - #endif - f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; - f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; - f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__neon_ld4lane_u4_prfm; - f32_gemm_config.mr = 6; - f32_gemm_config.nr = 8; - break; - case cpuinfo_uarch_cortex_a35: - case cpuinfo_uarch_cortex_a55: - case cpuinfo_uarch_kryo: - 
f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53); - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a55); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a55); - #if XNN_ENABLE_GEMM_M_SPECIALIZATION - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a55); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a55); - #endif - f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; - f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; - f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__neon_ld4lane_u4_prfm; - f32_gemm_config.mr = 6; - f32_gemm_config.nr = 8; - break; - case cpuinfo_uarch_cortex_a73: - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a75_prfm); - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a73); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a75_prfm); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a73); - f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; - f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; - f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__neon_ld4lane_u4_prfm; - f32_gemm_config.mr = 6; - f32_gemm_config.nr = 8; - break; - case cpuinfo_uarch_cortex_a77: - case cpuinfo_uarch_exynos_m5: - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a75); - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a75); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a75); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a75); - f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; - f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; - f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__neon_ld4lane_u4_prfm; - f32_gemm_config.mr = 4; - f32_gemm_config.nr = 8; - break; - case cpuinfo_uarch_cortex_x3: 
- case cpuinfo_uarch_neoverse_v2: - // TODO(fbarchard): Implement asm with indexed inputs - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_ld128_acc2); - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_6x8__aarch64_neonfma_lane_ld128); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_ld64); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_6x8__aarch64_neonfma_lane_ld128); - f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; - f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; - f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__neon_ld4lane_u4_prfm; - f32_gemm_config.mr = 6; - f32_gemm_config.nr = 8; - break; - case cpuinfo_uarch_cortex_a78: - case cpuinfo_uarch_cortex_a510: - case cpuinfo_uarch_cortex_a710: - case cpuinfo_uarch_cortex_a715: - case cpuinfo_uarch_cortex_x1: - case cpuinfo_uarch_cortex_x2: - case cpuinfo_uarch_neoverse_n2: - case cpuinfo_uarch_neoverse_v1: - default: - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_ld128_acc4); - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_6x8__asm_aarch64_neonfma_ld128); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_ld64); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_6x8__asm_aarch64_neonfma_ld128); - #if XNN_ENABLE_GEMM_M_SPECIALIZATION - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_ld128); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch64_neonfma_ld128); - #endif - f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; - f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; - f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__neon_ld4lane_u4_prfm; - f32_gemm_config.mr = 6; - f32_gemm_config.nr = 8; - break; - } - #if XNN_MAX_UARCH_TYPES > 1 - { - /* Choose micro-kernels for little cores according to micro-kernel specification for the big core */ - const uint32_t mr = f32_gemm_config.mr; - const uint32_t nr = f32_gemm_config.nr; - const uint32_t log2_sr = f32_gemm_config.log2_sr; - for (size_t i = 1; i < XNN_MAX_UARCH_TYPES; i++) { - const struct cpuinfo_uarch_info* uarch_info = cpuinfo_get_uarch(i); - if (uarch_info == NULL) { - /* No more microarchitectures in the system */ + f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; + f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; + f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__neon_ld4lane_u4_prfm; + f32_gemm_config.mr = 4; + f32_gemm_config.nr = 8; break; - } + 
case cpuinfo_uarch_cortex_a57: + case cpuinfo_uarch_cortex_a75: + case cpuinfo_uarch_cortex_a76: + case cpuinfo_uarch_exynos_m3: + case cpuinfo_uarch_exynos_m4: + case cpuinfo_uarch_neoverse_n1: + f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a75_prfm); + f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a75_prfm); + f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a75_prfm); + f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a75_prfm); + #if XNN_ENABLE_GEMM_M_SPECIALIZATION + f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a75_prfm); + f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a75_prfm); + #endif + f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; + f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; + f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__neon_ld4lane_u4_prfm; + f32_gemm_config.mr = 6; + f32_gemm_config.nr = 8; + break; + case cpuinfo_uarch_exynos_m1: + case cpuinfo_uarch_exynos_m2: + f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8s4__neonfma); + f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_6x8s4__neonfma); + f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8s4__neonfma); + f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_6x8s4__neonfma); + #if XNN_ENABLE_GEMM_M_SPECIALIZATION + f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8s4__neonfma); + f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8s4__neonfma); + #endif + f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; + f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; + f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8s4__neon_ld4lane_u4_prfm; + f32_gemm_config.mr = 6; + f32_gemm_config.nr = 8; + f32_gemm_config.log2_sr = 2; + break; + case cpuinfo_uarch_cortex_a53: + f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53_prfm); + f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a53_prfm); + f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53_prfm); + 
f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a53_prfm); + #if XNN_ENABLE_GEMM_M_SPECIALIZATION + f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a53_prfm); + f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a53_prfm); + #endif + f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; + f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; + f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__neon_ld4lane_u4_prfm; + f32_gemm_config.mr = 6; + f32_gemm_config.nr = 8; + break; + case cpuinfo_uarch_cortex_a55r0: + f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53); + f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a53); + f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53); + f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a53); + #if XNN_ENABLE_GEMM_M_SPECIALIZATION + f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a53); + f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a53); + #endif + f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; + f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; + f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__neon_ld4lane_u4_prfm; + f32_gemm_config.mr = 6; + f32_gemm_config.nr = 8; + break; + case cpuinfo_uarch_cortex_a35: + case cpuinfo_uarch_cortex_a55: + case cpuinfo_uarch_kryo: + f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53); + f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a55); + f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53); + f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a55); + #if XNN_ENABLE_GEMM_M_SPECIALIZATION + f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a55); + f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a55); + #endif + f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; + 
f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; + f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__neon_ld4lane_u4_prfm; + f32_gemm_config.mr = 6; + f32_gemm_config.nr = 8; + break; + case cpuinfo_uarch_cortex_a73: + f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a75_prfm); + f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a73); + f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a75_prfm); + f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a73); + f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; + f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; + f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__neon_ld4lane_u4_prfm; + f32_gemm_config.mr = 6; + f32_gemm_config.nr = 8; + break; + case cpuinfo_uarch_cortex_a77: + case cpuinfo_uarch_exynos_m5: + f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a75); + f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a75); + f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a75); + f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a75); + f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; + f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; + f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__neon_ld4lane_u4_prfm; + f32_gemm_config.mr = 4; + f32_gemm_config.nr = 8; + break; + case cpuinfo_uarch_cortex_x3: + case cpuinfo_uarch_neoverse_v2: + // TODO(fbarchard): Implement asm with indexed inputs + f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_ld128_acc2); + f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_6x8__aarch64_neonfma_lane_ld128); + f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_ld64); + f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_6x8__aarch64_neonfma_lane_ld128); + f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; + f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; + f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__neon_ld4lane_u4_prfm; + f32_gemm_config.mr = 6; + f32_gemm_config.nr = 8; + break; + case cpuinfo_uarch_cortex_a78: 
+ case cpuinfo_uarch_cortex_a510: + case cpuinfo_uarch_cortex_a710: + case cpuinfo_uarch_cortex_a715: + case cpuinfo_uarch_cortex_x1: + case cpuinfo_uarch_cortex_x2: + case cpuinfo_uarch_neoverse_n2: + case cpuinfo_uarch_neoverse_v1: + default: + f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_ld128_acc4); + f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_6x8__asm_aarch64_neonfma_ld128); + f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_ld64); + f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_6x8__asm_aarch64_neonfma_ld128); + #if XNN_ENABLE_GEMM_M_SPECIALIZATION + f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_ld128); + f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch64_neonfma_ld128); + #endif + f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; + f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; + f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__neon_ld4lane_u4_prfm; + f32_gemm_config.mr = 6; + f32_gemm_config.nr = 8; + break; + } + #if XNN_MAX_UARCH_TYPES > 1 + { + /* Choose micro-kernels for little cores according to micro-kernel specification for the big core */ + const uint32_t mr = f32_gemm_config.mr; + const uint32_t nr = f32_gemm_config.nr; + const uint32_t log2_sr = f32_gemm_config.log2_sr; + for (size_t i = 1; i < XNN_MAX_UARCH_TYPES; i++) { + const struct cpuinfo_uarch_info* uarch_info = cpuinfo_get_uarch(i); + if (uarch_info == NULL) { + /* No more microarchitectures in the system */ + break; + } - switch (uarch_info->uarch) { - case cpuinfo_uarch_cortex_a53: - if (mr == 6 && nr == 8 && log2_sr == 0) { - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53_prfm; - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(6)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a53_prfm; - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53_prfm; - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(6)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a53_prfm; - #if XNN_ENABLE_GEMM_M_SPECIALIZATION + switch (uarch_info->uarch) { + case cpuinfo_uarch_cortex_a53: + if (mr == 6 && nr == 8 && log2_sr == 0) { + f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53_prfm; + f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(6)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a53_prfm; + f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53_prfm; + f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(6)].function[i] = (xnn_igemm_ukernel_fn) 
xnn_f32_igemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a53_prfm; + #if XNN_ENABLE_GEMM_M_SPECIALIZATION + f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a53_prfm; + f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a53_prfm; + #endif + } else if (mr == 4 && nr == 8 && log2_sr == 0) { + f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53_prfm; f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a53_prfm; + f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53_prfm; f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a53_prfm; - #endif - } else if (mr == 4 && nr == 8 && log2_sr == 0) { - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53_prfm; - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a53_prfm; - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53_prfm; - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a53_prfm; - } - break; - case cpuinfo_uarch_cortex_a55r0: - if (mr == 6 && nr == 8 && log2_sr == 0) { - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53; - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(6)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a53; - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53; - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(6)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a53; - #if XNN_ENABLE_GEMM_M_SPECIALIZATION + } + break; + case cpuinfo_uarch_cortex_a55r0: + if (mr == 6 && nr == 8 && log2_sr == 0) { + f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53; + f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(6)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a53; + f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53; + f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(6)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a53; + #if XNN_ENABLE_GEMM_M_SPECIALIZATION + f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a53; + f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)].function[i] = (xnn_igemm_ukernel_fn) 
xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a53; + #endif + } else if (mr == 4 && nr == 8 && log2_sr == 0) { + f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53; f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a53; + f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53; f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a53; - #endif - } else if (mr == 4 && nr == 8 && log2_sr == 0) { - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53; - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a53; - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53; - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a53; - } - break; - case cpuinfo_uarch_cortex_a55: - if (mr == 6 && nr == 8 && log2_sr == 0) { - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53; - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(6)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a55; - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53; - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(6)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a55; - #if XNN_ENABLE_GEMM_M_SPECIALIZATION + } + break; + case cpuinfo_uarch_cortex_a55: + if (mr == 6 && nr == 8 && log2_sr == 0) { + f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53; + f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(6)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a55; + f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53; + f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(6)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a55; + #if XNN_ENABLE_GEMM_M_SPECIALIZATION + f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a55; + f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a55; + #endif + } else if (mr == 4 && nr == 8 && log2_sr == 0) { + f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53; f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a55; + 
f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53; f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a55; - #endif - } else if (mr == 4 && nr == 8 && log2_sr == 0) { - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53; - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a55; - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53; - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a55; - } - break; - default: - break; + } + break; + default: + break; + } } } - } #endif // XNN_MAX_UARCH_TYPES > 1 #else // XNN_ENABLE_ASSEMBLY && !XNN_PLATFORM_IOS && !XNN_PLATFORM_MAC #if XNN_ENABLE_ASSEMBLY @@ -3850,6 +3873,15 @@ const struct xnn_gemm_config* xnn_init_f16_gemm_config() { return &f16_gemm_config; } +const struct xnn_gemm_config* xnn_init_pf32_gemm_config() { + const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); + if (hardware_config == NULL) { + return NULL; + } + XNN_INIT_ONCE(pf32_gemm); + return &pf32_gemm_config; +} + const struct xnn_gemm_config* xnn_init_f32_gemm_config() { const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); if (hardware_config == NULL) { diff --git a/src/configs/pack-lh-config.c b/src/configs/pack-lh-config.c new file mode 100644 index 00000000000..fa8362c7ae2 --- /dev/null +++ b/src/configs/pack-lh-config.c @@ -0,0 +1,41 @@ +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ +#include <assert.h> +#include <stddef.h> +#include <stdint.h> + +#include "xnnpack/common.h" +#include "xnnpack/config.h" +#include "xnnpack/hardware-config.h" +#include "xnnpack/init-once.h" +#include "xnnpack/microfnptr.h" +#include "xnnpack/pack-lh.h" + +static struct xnn_pack_lh_config x32_pack_lh_config = {0}; + +XNN_INIT_ONCE_GUARD(x32_pack_lh); + +static void init_x32_pack_lh_config(void) { +#if XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI + #if XNN_ENABLE_ARM_SME2 + const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); + assert(hardware_config != NULL); + if (hardware_config->use_arm_sme2) { + x32_pack_lh_config.ukernel = (xnn_x32_pack_lh_ukernel_fn) xnn_x32_pack_lh_ukernel__neonsme2; + x32_pack_lh_config.size_fn = (xnn_x32_pack_lh_size_fn) xnn_x32_pack_lh_size__neonsme2; + } + #endif // XNN_ENABLE_ARM_SME2 +#endif // XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI +} + +const struct xnn_pack_lh_config* xnn_init_x32_pack_lh_config() { + const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); + if (hardware_config == NULL) { + return NULL; + } + XNN_INIT_ONCE(x32_pack_lh); + return &x32_pack_lh_config; +} diff --git a/src/enums/datatype-strings.c b/src/enums/datatype-strings.c index 7f54f334f44..bf30c63afde 100644 --- a/src/enums/datatype-strings.c +++ b/src/enums/datatype-strings.c @@ -20,6 +20,8 @@ const char* xnn_datatype_to_string(enum xnn_datatype type) { return "FP32"; case xnn_datatype_fp16: return "FP16"; + case xnn_datatype_pfp32: + return "PFP32"; case xnn_datatype_qint8: return "QINT8"; case xnn_datatype_quint8: diff --git a/src/operator-run.c b/src/operator-run.c index 9bec626c35b..bab0e960a7a 100644 --- a/src/operator-run.c +++ b/src/operator-run.c @@ -2333,6 +2333,19 @@ void xnn_compute_f32_qd8_convert( context->convert_ukernel(n, input, output, &params); } +void xnn_compute_x32_pack_lh( + const struct x32_pack_lh_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t m_idx_start, size_t tile) { + const float* lhs = (const float*)((const char*)context->lhs + + m_idx_start * context->lhs_stride); + const size_t offset = context->k * m_idx_start; + float* lhs_packed = context->lhs_packed + offset; + + context->pack_lh_ukernel(/*m=*/tile, context->k, context->mr, context->kr, + context->sr, 0, (const uint32_t*) lhs, context->lhs_stride, + (uint32_t*) lhs_packed); +} + void xnn_compute_f32_qp8_convert( const struct f32_qp8_convert_context context[restrict XNN_MIN_ELEMENTS(1)], size_t m_idx_start) { diff --git a/src/operators/fully-connected-nc.c b/src/operators/fully-connected-nc.c index 169cd0c7a8c..2914e38dd19 100644 --- a/src/operators/fully-connected-nc.c +++ b/src/operators/fully-connected-nc.c @@ -1280,7 +1280,7 @@ enum xnn_status xnn_create_fully_connected_nc_f32_f16( return status; } -enum xnn_status xnn_create_fully_connected_nc_f32( +enum xnn_status create_fully_connected_nc_f32( size_t input_channels, size_t output_channels, size_t input_stride, @@ -1292,6 +1292,7 @@ enum xnn_status xnn_create_fully_connected_nc_f32( uint32_t flags, xnn_code_cache_t code_cache, xnn_weights_cache_t weights_cache, + const struct xnn_gemm_config* gemm_config, xnn_operator_t* fully_connected_op_out) { if (isnan(output_min)) { @@ -1315,21 +1316,6 @@ enum xnn_status xnn_create_fully_connected_nc_f32( return xnn_status_invalid_parameter; } - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); - if (gemm_config == NULL) { - xnn_log_error("failed to create %s operator: unsupported hardware configuration", -
xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_f32)); - return xnn_status_unsupported_hardware; - } - - const struct xnn_gemm_config* gemm_nr2_config = xnn_init_f32_gemm_nr2_config(); - if (gemm_config->nr > output_channels) { - // Default microkernel is suboptimal, use a microkernel that better supports less output channels. - if (gemm_nr2_config != NULL && gemm_nr2_config->minmax.gemm[gemm_nr2_config->mr-1].function[XNN_UARCH_DEFAULT] != NULL) { - gemm_config = gemm_nr2_config; - } - } - const struct gemm_fused_ukernels* gemm_ukernels = &gemm_config->minmax; const bool linear_activation = (output_max == INFINITY) && (output_min == -output_max); if (linear_activation && gemm_config->linear.gemm[gemm_config->mr-1].function[XNN_UARCH_DEFAULT] != NULL) { @@ -1367,6 +1353,60 @@ enum xnn_status xnn_create_fully_connected_nc_f32( fully_connected_op_out); } +enum xnn_status xnn_create_fully_connected_nc_f32( + size_t input_channels, + size_t output_channels, + size_t input_stride, + size_t output_stride, + const float* kernel, + const float* bias, + float output_min, + float output_max, + uint32_t flags, + xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, + xnn_operator_t* fully_connected_op_out) { + const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + if (gemm_config == NULL) { + xnn_log_error("failed to create %s operator: unsupported hardware configuration", + xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_f32)); + return xnn_status_unsupported_hardware; + } + + const struct xnn_gemm_config* gemm_nr2_config = xnn_init_f32_gemm_nr2_config(); + if (gemm_config->nr > output_channels) { + // Default microkernel is suboptimal, use a microkernel that better supports less output channels. + if (gemm_nr2_config != NULL && gemm_nr2_config->minmax.gemm[gemm_nr2_config->mr-1].function[XNN_UARCH_DEFAULT] != NULL) { + gemm_config = gemm_nr2_config; + } + } + + return create_fully_connected_nc_f32(input_channels, output_channels, input_stride, output_stride, kernel, bias, output_min, output_max, flags, code_cache, weights_cache, gemm_config, fully_connected_op_out); +} + +enum xnn_status xnn_create_fully_connected_nc_pf32( + size_t input_channels, + size_t output_channels, + size_t input_stride, + size_t output_stride, + const float* kernel, + const float* bias, + float output_min, + float output_max, + uint32_t flags, + xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, + xnn_operator_t* fully_connected_op_out) { + const struct xnn_gemm_config* gemm_config = xnn_init_pf32_gemm_config(); + if (gemm_config == NULL) { + xnn_log_error("failed to create %s operator: unsupported hardware configuration", + xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_pf32)); + return xnn_status_unsupported_hardware; + } + + return create_fully_connected_nc_f32(input_channels, output_channels, input_stride, output_stride, kernel, bias, output_min, output_max, flags, code_cache, weights_cache, gemm_config, fully_connected_op_out); +} + enum xnn_status xnn_create_fully_connected_nc_f32_qc4w( size_t input_channels, size_t output_channels, diff --git a/src/operators/pack-lh.c b/src/operators/pack-lh.c new file mode 100644 index 00000000000..eb576875a77 --- /dev/null +++ b/src/operators/pack-lh.c @@ -0,0 +1,140 @@ +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ +#include "xnnpack.h" +#include "xnnpack/allocator.h" +#include "xnnpack/config-types.h" +#include "xnnpack/config.h" +#include "xnnpack/log.h" +#include "xnnpack/operator-type.h" +#include "xnnpack/operator.h" +#include "xnnpack/params.h" + +enum xnn_status xnn_create_pack_lh_x32( + uint32_t flags, + xnn_operator_t* pack_lh_op_out) +{ + const struct xnn_pack_lh_config *pack_lh_config = xnn_init_x32_pack_lh_config(); + xnn_operator_t pack_lh_op = NULL; + + if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) { + xnn_log_error("failed to create %s operator: XNNPACK is not initialized", + xnn_operator_type_to_string(xnn_operator_type_pack_lh_x32)); + return xnn_status_uninitialized; + } + + if (pack_lh_config == NULL) { + xnn_log_error( + "failed to create %s operator: unsupported hardware configuration", + xnn_operator_type_to_string(xnn_operator_type_pack_lh_x32)); + return xnn_status_unsupported_hardware; + } + + pack_lh_op = xnn_allocate_zero_simd_memory(sizeof(struct xnn_operator)); + if (pack_lh_op == NULL) { + xnn_log_error( + "failed to allocate %zu bytes for %s operator descriptor", + sizeof(struct xnn_operator), xnn_operator_type_to_string(xnn_operator_type_pack_lh_x32)); + return xnn_status_out_of_memory; + } + + pack_lh_op->pack_lh_config = pack_lh_config; + pack_lh_op->type = xnn_operator_type_pack_lh_x32; + pack_lh_op->flags = flags; + pack_lh_op->state = xnn_run_state_invalid; + + *pack_lh_op_out = pack_lh_op; + return xnn_status_success; +} + +enum xnn_status xnn_reshape_pack_lh_x32( + xnn_operator_t pack_lh_op, + size_t batch_size, + size_t channels, + size_t *output_size_bytes, + pthreadpool_t threadpool) +{ + if (pack_lh_op->type != xnn_operator_type_pack_lh_x32) { + xnn_log_error( + "failed to reshape operator: operator type mismatch (expected %s, got " + "%s)", + xnn_operator_type_to_string(xnn_operator_type_pack_lh_x32), + xnn_operator_type_to_string(pack_lh_op->type)); + return xnn_status_invalid_parameter; + } + pack_lh_op->state = xnn_run_state_invalid; + + if (batch_size == 0) { + pack_lh_op->state = xnn_run_state_skip; + return xnn_status_success; + } + + pack_lh_op->batch_size = batch_size; + + const struct xnn_pack_lh_config *pack_lh_config = xnn_init_x32_pack_lh_config(); + const struct xnn_gemm_config* gemm_config = + xnn_init_pf32_gemm_config(); + const uint32_t mr_packed = batch_size == 1 ? 
1 : gemm_config->mr_packed; + const size_t mr = gemm_config->mr; + const uint32_t kr = UINT32_C(1) << gemm_config->log2_kr; + const uint32_t sr = UINT32_C(1) << gemm_config->log2_sr; + + pack_lh_op->context.x32_pack_lh = (struct x32_pack_lh_context) { + .m = batch_size, + .k = channels, + .mr = mr,//mr_packed, + .kr = kr, + .sr = sr, + .lhs_stride = channels * sizeof(float), + .pack_lh_ukernel = (xnn_x32_pack_lh_ukernel_fn) + pack_lh_op->pack_lh_config->ukernel, + }; + + *output_size_bytes = pack_lh_config->size_fn(batch_size, channels, mr, kr, sr); + pack_lh_op->compute[0].type = xnn_parallelization_type_1d_tile_1d; + pack_lh_op->compute[0].task_1d = + (pthreadpool_task_1d_t) xnn_compute_x32_pack_lh; + pack_lh_op->compute[0].range[0] = batch_size; + pack_lh_op->compute[0].tile[0] = mr_packed; + + pack_lh_op->state = xnn_run_state_needs_setup; + + return xnn_status_success; +} + +enum xnn_status xnn_setup_pack_lh_x32( + xnn_operator_t pack_lh_op, + const void* input, + void* output) +{ + if (pack_lh_op->type != xnn_operator_type_pack_lh_x32) { + xnn_log_error( + "failed to setup operator: operator type mismatch (expected %s, got " + "%s)", + xnn_operator_type_to_string(xnn_operator_type_pack_lh_x32), + xnn_operator_type_to_string(pack_lh_op->type)); + return xnn_status_invalid_parameter; + } + switch (pack_lh_op->state) { + case xnn_run_state_skip: + return xnn_status_success; + case xnn_run_state_invalid: + xnn_log_error( + "failed to setup %s operator: operator has not been reshaped yet", + xnn_operator_type_to_string(pack_lh_op->type)); + return xnn_status_invalid_state; + case xnn_run_state_needs_setup: + // Operator has been reshaped, but not setup, continue with setup. + case xnn_run_state_ready: + // Operator has been reshaped, and we are setting up with different + // pointers. 
+ break; + } + + pack_lh_op->context.x32_pack_lh.lhs = input; + pack_lh_op->context.x32_pack_lh.lhs_packed = output; + pack_lh_op->state = xnn_run_state_ready; + return xnn_status_success; +} diff --git a/src/packing.cc b/src/packing.cc index 19cf4e357af..1bd7912db06 100644 --- a/src/packing.cc +++ b/src/packing.cc @@ -23,10 +23,11 @@ #include "xnnpack/unaligned.h" #if XNN_ENABLE_KLEIDIAI + #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.h" + #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.h" #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0.h" - #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h" #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h" - #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.h" + #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h" #endif // XNN_ENABLE_KLEIDIAI @@ -1555,6 +1556,73 @@ void xnn_pack_kai_qs4_weights_and_biases( } } +size_t xnn_packed_stride_kai_f32_weights_and_biases( + const struct xnn_gemm_config* gemm_config, size_t k, size_t unused_k_stride, + size_t extra_bytes) { + size_t ret_val = + kai_get_rhs_packed_stride_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(k) / + kai_get_n_step_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(); + return ret_val; +} + +void transpose_weights(const float* in, float* out, size_t height, size_t width) { + for (size_t i = 0; i < height; ++i) { + for (size_t j = 0; j < width; ++j) { + out[j * height + i] = in[i * width + j]; + } + } +} + +void xnn_pack_kai_f32_weights_and_biases( + uint32_t flags, const struct xnn_gemm_config* gemm_config, + size_t input_channels, size_t output_channels, size_t groups, + size_t k_stride, const void* accumulator_init, const void* weights, + xnn_init_scale_params_fn init_extra_data0_fn, const void* extra_data0, + size_t extra_data0_element_size, + xnn_init_scale_params_fn init_extra_data1_fn, const void* extra_data1, + size_t extra_data1_element_size, void* packed_weights_ptr, + const void* params) { + assert(extra_data0 == nullptr); + assert(extra_data1 == nullptr); + const uint32_t nr = gemm_config->nr; + const uint32_t kr = UINT32_C(1) << gemm_config->log2_kr; + const uint32_t sr = UINT32_C(1) << gemm_config->log2_sr; + const size_t rhs_stride = output_channels * sizeof(float); + + // Some packing kernels assume that the bias is non-null. Allocate a zero + // initialized array as a workaround if bias is null. + bool free_accumulator_init = false; + if (accumulator_init == NULL) { + accumulator_init = calloc(output_channels, sizeof(float)); + free_accumulator_init = true; + } + if (flags & XNN_FLAG_TRANSPOSE_WEIGHTS) { + kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme( + groups, output_channels, input_channels, nr, kr, sr, rhs_stride, + /*rhs=*/reinterpret_cast<const float*>(weights), + /*bias=*/reinterpret_cast<const float*>(accumulator_init), + /*scale=*/reinterpret_cast<const float*>(extra_data1), + /*rhs_packed=*/packed_weights_ptr, + /*extra_bytes=*/extra_data0_element_size + extra_data1_element_size, NULL); + } else { + // Transpose the weights until the transpose packing function is ready.
+ float* tmp_data = + (float*) malloc(input_channels * output_channels * sizeof(float)); + transpose_weights((const float*) weights, tmp_data, output_channels, input_channels); + kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme( + groups, output_channels, input_channels, nr, kr, sr, rhs_stride, + /*rhs=*/reinterpret_cast<const float*>(tmp_data), + /*bias=*/reinterpret_cast<const float*>(accumulator_init), + /*scale=*/reinterpret_cast<const float*>(extra_data1), + /*rhs_packed=*/packed_weights_ptr, + /*extra_bytes=*/extra_data0_element_size + extra_data1_element_size, NULL); + free(tmp_data); + } + if (free_accumulator_init) { + free((void*) accumulator_init); + } +} + size_t xnn_packed_stride_kai_qb4_weights_and_biases( const struct xnn_gemm_config* gemm_config, size_t k, size_t block_size, size_t extra_bytes) { @@ -1565,13 +1633,9 @@ size_t xnn_packed_stride_kai_qb4_weights_and_biases( // We want the weight stride with nr = 1, but kleidi enforces a constraint // where nr % 4 == 0. So instead we give nr to get the nr-scaled stride, and // divide by nr to scaled down the stride. - const size_t nr_scaled_packed_stride = kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( - k, - nr, - kr, - sr, - block_size, - kai_datatype::kai_dt_bf16); + const size_t nr_scaled_packed_stride = + kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( + k, nr, kr, sr, block_size, kai_datatype::kai_dt_bf16); return nr_scaled_packed_stride / nr; } diff --git a/src/pf32-gemm/pf32-gemm-32x32-minmax-neonsme2.c b/src/pf32-gemm/pf32-gemm-32x32-minmax-neonsme2.c new file mode 100644 index 00000000000..a0c5666658b --- /dev/null +++ b/src/pf32-gemm/pf32-gemm-32x32-minmax-neonsme2.c @@ -0,0 +1,41 @@ +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include <assert.h> + +#include "xnnpack/microparams.h" + +#if XNN_ENABLE_KLEIDIAI + // Keep this line indented to avoid it being pulled out of the #ifdef when the + // sources are amalgamated. + #include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.h" +#endif // XNN_ENABLE_KLEIDIAI + +size_t xnn_pf32_gemm_minmax_ukernel_32x32__neonsme2_get_mr() { + return kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(); +} + +size_t xnn_pf32_gemm_minmax_ukernel_32x32__neonsme2_get_nr() { + return kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(); +} + +// Wraps the `kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa` +// GEMM microkernel with a name that is compatible with our tooling. +void xnn_pf32_gemm_minmax_ukernel_32x32__neonsme2( + size_t m, size_t n, size_t k, const void* lhs_packed, size_t lhs_stride, + const void* rhs_packed, float* dst, size_t dst_stride_row, + size_t dst_stride_col, + union xnn_f32_minmax_params + minmax_params[XNN_RESTRICT XNN_MIN_ELEMENTS(1)]) { +#if XNN_ENABLE_KLEIDIAI + kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa( + m, n, k / sizeof(float), lhs_packed, rhs_packed, dst, dst_stride_row, /*dst_stride_col=*/sizeof(float), + minmax_params->scalar.min, minmax_params->scalar.max); +#else + assert( + "Calling KleidiAI microkernel wrapper, but XNNPACK was compiled without " + "`XNN_ENABLE_KLEIDIAI`."
&& 0); +#endif // XNN_ENABLE_KLEIDIAI +} diff --git a/src/subgraph.c b/src/subgraph.c index cc26fbf4d0c..6e8735ac47e 100644 --- a/src/subgraph.c +++ b/src/subgraph.c @@ -19,6 +19,7 @@ #include "xnnpack/common.h" #include "xnnpack/fp16.h" #include "xnnpack/hardware-config.h" +#include "xnnpack/internal.h" #include "xnnpack/log.h" #include "xnnpack/math.h" #include "xnnpack/node-type.h" @@ -73,6 +74,24 @@ enum xnn_status xnn_insert_clamp_node(xnn_subgraph_t subgraph, float output_min, return xnn_define_clamp(subgraph, output_min, output_max, new_id, output_id, /*flags=*/0); } +enum xnn_status xnn_insert_pack_lh_node(xnn_subgraph_t subgraph, const struct xnn_value* input, uint32_t input_id, uint32_t *new_id) { + enum xnn_status status; + switch (input->datatype) { + case xnn_datatype_fp32: + status = xnn_define_tensor_value( + subgraph, xnn_datatype_fp32, 0, NULL, NULL, + /*external_id=*/XNN_INVALID_VALUE_ID, /*flags=*/0, new_id); + break; + default: + XNN_UNREACHABLE; + } + if (status != xnn_status_success) { + return status; + } + + return xnn_define_pack_lh(subgraph, input_id, *new_id, /*flags=*/0); +} + enum xnn_status xnn_create_subgraph( uint32_t external_value_ids, uint32_t flags, diff --git a/src/subgraph/fully-connected.c b/src/subgraph/fully-connected.c index 2a6459b818b..89e6c40ba69 100644 --- a/src/subgraph/fully-connected.c +++ b/src/subgraph/fully-connected.c @@ -11,11 +11,13 @@ #include "xnnpack.h" #include "xnnpack/allocation-type.h" #include "xnnpack/common.h" +#include "xnnpack/config.h" #include "xnnpack/internal.h" #include "xnnpack/log.h" #include "xnnpack/node-type.h" #include "xnnpack/operator-type.h" #include "xnnpack/operator.h" +#include "xnnpack/subgraph.h" #include "xnnpack/requantization.h" #include "xnnpack/subgraph-validation.h" #include "xnnpack/subgraph.h" @@ -43,6 +45,7 @@ enum fully_connected_op_type { fc_type_qs8_qs8_qs8 = 17, fc_type_qu8_qu8_qu8 = 18, fc_type_qp8_f32_qb4w = 19, + fc_type_pf32_f32_f32 = 20, }; enum fully_connected_op_type get_fully_connected_op_type( @@ -93,7 +96,14 @@ enum fully_connected_op_type get_fully_connected_op_type( if (has_non_static_weights) { return fc_type_f32_f32_f32_dynamic; } else { - return fc_type_f32_f32_f32; + switch (input_datatype) { + case xnn_datatype_fp32: + return fc_type_f32_f32_f32; + case xnn_datatype_pfp32: + return fc_type_pf32_f32_f32; + default: + XNN_UNREACHABLE; + } } case xnn_datatype_qbint4: switch (input_datatype) { @@ -269,6 +279,15 @@ static enum xnn_status create_fully_connected_operator( /*flags=*/node->flags, code_cache, weights_cache, &opdata->operator_objects[0]); break; + case fc_type_pf32_f32_f32: + status = xnn_create_fully_connected_nc_pf32( + input_channels, output_channels, + /*input_stride=*/input_channels, + /*output_stride=*/output_channels, kernel_data, bias_data, + node->activation.output_min, node->activation.output_max, + /*flags=*/node->flags, code_cache, weights_cache, + &opdata->operator_objects[0]); + break; case fc_type_qd8_f32_qb4w: status = xnn_create_fully_connected_nc_qd8_f32_qb4w( input_channels, output_channels, @@ -1226,6 +1245,18 @@ enum xnn_status xnn_define_fully_connected(xnn_subgraph_t subgraph, } } + if (compute_type == xnn_compute_type_fp32) { + const struct xnn_gemm_config* gemm_config = xnn_init_pf32_gemm_config(); + if (gemm_config != NULL && gemm_config->init.f32 != NULL) { + // Insert a node to pack the LHS. 
+ uint32_t new_id = XNN_INVALID_VALUE_ID; + status = xnn_insert_pack_lh_node(subgraph, input_value, input_id, &new_id); + if (status != xnn_status_success) { + return status; + } + input_id = new_id; + } + } struct xnn_node* node = xnn_subgraph_new_node(subgraph); if (node == NULL) { return xnn_status_out_of_memory; diff --git a/src/subgraph/pack-lh.c b/src/subgraph/pack-lh.c new file mode 100644 index 00000000000..63dc8e070b3 --- /dev/null +++ b/src/subgraph/pack-lh.c @@ -0,0 +1,210 @@ +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include +#include + +#include "xnnpack.h" +#include "xnnpack/common.h" +#include "xnnpack/internal.h" +#include "xnnpack/log.h" +#include "xnnpack/node-type.h" +#include "xnnpack/operator-type.h" +#include "xnnpack/operator.h" +#include "xnnpack/reshape-helpers.h" +#include "xnnpack/subgraph-validation.h" +#include "xnnpack/subgraph.h" +#include "pthreadpool.h" + +static enum xnn_status create_pack_lh_operator( + const struct xnn_node* node, + const struct xnn_value* values, + size_t num_values, + struct xnn_operator_data* opdata, + struct xnn_code_cache* code_cache, + xnn_weights_cache_t weights_cache) +{ + assert(node->num_inputs == 1); + + assert(node->num_outputs == 1); + + const uint32_t input_id = node->inputs[0]; + assert(input_id < num_values); + const struct xnn_value *input_value = &values[input_id]; + enum xnn_status status; + switch (input_value->datatype) { + case xnn_datatype_fp32: + status = xnn_create_pack_lh_x32( + node->flags, + &opdata->operator_objects[0]); + break; + default: + XNN_UNREACHABLE; + } + return status; +} + +static enum xnn_status reshape_pack_lh_operator( + struct xnn_operator_data* opdata, + struct xnn_value* values, + size_t num_values, + pthreadpool_t threadpool) +{ + const uint32_t input_id = opdata->inputs[0]; + assert(input_id < num_values); + const struct xnn_value* input_value = &values[input_id]; + const uint32_t output_id = opdata->outputs[0]; + assert(output_id < num_values); + struct xnn_value* output_value = &values[output_id]; + const size_t batch_size = xnn_shape_multiply_non_channel_dims(&input_value->shape); + const size_t num_input_dims = input_value->shape.num_dims; + const size_t channels = num_input_dims == 0 ? 1 : input_value->shape.dim[num_input_dims - 1]; + const size_t old_workspace_size = opdata->workspace_size; + enum xnn_status status = xnn_status_invalid_state; + size_t output_size_bytes = 0; + + switch (opdata->operator_objects[0]->type) { + case xnn_operator_type_pack_lh_x32: + status = xnn_reshape_pack_lh_x32( + opdata->operator_objects[0], + batch_size, + channels, + &output_size_bytes, + threadpool); + break; + default: + XNN_UNREACHABLE; + } + if (status != xnn_status_success) { + return status; + } + // output shape is identical to input shape, however the data is packed into + // an appropriate format for the following operation so the number of bytes + // required cannot be determined from the shape alone. 
+ output_value->shape.num_dims = num_input_dims; + memcpy(&output_value->shape.dim[0], &input_value->shape.dim[0], num_input_dims * sizeof(size_t)); + if (output_size_bytes > output_value->size || opdata->workspace_size > old_workspace_size) { + output_value->size = output_size_bytes; + return xnn_status_reallocation_required; + } + return xnn_status_success; +} + +static enum xnn_status setup_pack_lh_operator( + const struct xnn_operator_data* opdata, + const struct xnn_value* values, + size_t num_values, + pthreadpool_t threadpool) +{ + const uint32_t input_id = opdata->inputs[0]; + assert(input_id != XNN_INVALID_VALUE_ID); + assert(input_id < num_values); + + const uint32_t output_id = opdata->outputs[0]; + assert(output_id != XNN_INVALID_VALUE_ID); + assert(output_id < num_values); + + const struct xnn_value* input_value = values + input_id; + const void* input_data = input_value->data; + assert(input_data != NULL); + + const struct xnn_value* output_value = values + output_id; + void* output_data = output_value->data; + assert(output_data != NULL); + + switch (opdata->operator_objects[0]->type) { + case xnn_operator_type_pack_lh_x32: + return xnn_setup_pack_lh_x32( + opdata->operator_objects[0], + input_data, + output_data); + default: + XNN_UNREACHABLE; + } +} + +enum xnn_status xnn_define_pack_lh( + xnn_subgraph_t subgraph, + uint32_t input_id, + uint32_t output_id, + uint32_t flags) +{ + enum xnn_status status; + if ((status = xnn_subgraph_check_xnnpack_initialized(xnn_node_type_pack_lh)) != xnn_status_success) { + return status; + } + + if ((status = xnn_subgraph_check_input_node_id(xnn_node_type_pack_lh, input_id, subgraph->num_values)) != + xnn_status_success) { + return status; + } + + const struct xnn_value* input_value = &subgraph->values[input_id]; + status = xnn_subgraph_check_input_type_dense(xnn_node_type_pack_lh, input_id, input_value); + if (status != xnn_status_success) { + return status; + } + + switch (input_value->datatype) { + case xnn_datatype_fp32: + break; + default: + xnn_log_error( + "failed to define %s operator with input ID #%" PRIu32 ": unsupported Value datatype %s (%d)", + xnn_node_type_to_string(xnn_node_type_pack_lh), input_id, + xnn_datatype_to_string(input_value->datatype), input_value->datatype); + return xnn_status_invalid_parameter; + } + + status = xnn_subgraph_check_output_node_id(xnn_node_type_pack_lh, output_id, subgraph->num_values); + if (status != xnn_status_success) { + return status; + } + + struct xnn_value* output_value = &subgraph->values[output_id]; + status = xnn_subgraph_check_output_type_dense(xnn_node_type_pack_lh, output_id, output_value); + if (status != xnn_status_success) { + return status; + } + + enum xnn_compute_type compute_type = xnn_compute_type_invalid; + switch (output_value->datatype) { + case xnn_datatype_fp32: + compute_type = xnn_compute_type_fp32; + // Coerce the output from `xnn_datatype_fp32` to `xnn_datatype_pfp32` so + // that the correct GEMM path is taken. 
+ output_value->datatype = xnn_datatype_pfp32; + break; + default: + xnn_log_error( + "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value datatype %s (%d)", + xnn_node_type_to_string(xnn_node_type_pack_lh), output_id, + xnn_datatype_to_string(output_value->datatype), output_value->datatype); + return xnn_status_invalid_parameter; + } + assert(compute_type != xnn_compute_type_invalid); + + struct xnn_node* node = xnn_subgraph_new_node(subgraph); + if (node == NULL) { + return xnn_status_out_of_memory; + } + + node->type = xnn_node_type_pack_lh; + node->compute_type = compute_type; + node->num_inputs = 1; + node->inputs[0] = input_id; + node->num_outputs = 1; + node->outputs[0] = output_id; + node->flags = flags; + + node->create = create_pack_lh_operator; + node->reshape = reshape_pack_lh_operator; + node->setup = setup_pack_lh_operator; + + return xnn_status_success; +} diff --git a/src/tensor.c b/src/tensor.c index 447c6faeb6c..6e30deeab22 100644 --- a/src/tensor.c +++ b/src/tensor.c @@ -126,6 +126,7 @@ enum xnn_status xnn_define_tensor_value( case xnn_datatype_fp32: case xnn_datatype_fp16: case xnn_datatype_int32: + case xnn_datatype_pfp32: break; default: xnn_log_error("failed to create Dense Tensor value: unsupported datatype %s (%d)", @@ -622,6 +623,7 @@ size_t xnn_tensor_get_size(const struct xnn_value* value) size = 2; break; case xnn_datatype_fp32: + case xnn_datatype_pfp32: size = 4; break; case xnn_datatype_qcint4: diff --git a/src/x32-pack-lh/x32-pack-lh.h b/src/x32-pack-lh/x32-pack-lh.h new file mode 100644 index 00000000000..6af063482c1 --- /dev/null +++ b/src/x32-pack-lh/x32-pack-lh.h @@ -0,0 +1,34 @@ +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#ifndef XNN_UKERNEL_WITH_PARAMS +#define XNN_UKERNEL_WITH_PARAMS(arch_flags, ukernel, unroll, params_type, \ + init_params) \ + XNN_UKERNEL(arch_flags, ukernel, unroll) +#define XNN_DEFINED_UKERNEL_WITH_PARAMS +#endif + +#ifndef XNN_UKERNEL +#define XNN_UKERNEL(arch_flags, ukernel, unroll) \ + XNN_UKERNEL_WITH_PARAMS(arch_flags, ukernel, unroll) +#define XNN_DEFINED_UKERNEL +#endif + +// arch_flags, ukernel, unroll + +#if XNN_ENABLE_KLEIDIAI +XNN_UKERNEL(xnn_arch_arm_sme, xnn_x32_pack_lh_ukernel__neonsme2, + xnn_x32_pack_lh_size__neonsme2) +#endif // XNN_ENABLE_KLEIDIAI + +#ifdef XNN_DEFINED_UKERNEL_WITH_PARAMS +#undef XNN_DEFINED_UKERNEL_WITH_PARAMS +#undef XNN_UKERNEL_WITH_PARAMS +#endif + +#ifdef XNN_DEFINED_UKERNEL +#undef XNN_DEFINED_UKERNEL +#undef XNN_UKERNEL +#endif diff --git a/src/x32-pack-lh/x32-packlh-neonsme2.c b/src/x32-pack-lh/x32-packlh-neonsme2.c new file mode 100644 index 00000000000..999894e798c --- /dev/null +++ b/src/x32-pack-lh/x32-packlh-neonsme2.c @@ -0,0 +1,46 @@ +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include +#include +#include + +#include "xnnpack/common.h" +#include "xnnpack/math.h" +#include "xnnpack/pack-lh.h" + +#if XNN_ENABLE_KLEIDIAI + // Keep this line indented to avoid it being pulled out of the #ifdef when the + // sources are amalgamated. + #include "kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.h" +#endif // XNN_ENABLE_KLEIDIAI + + +// This function just wraps KleidiAI's `kai_run_lhs_pack_f32p2vlx1_f32_sme`, but +// with a name that is recognized by our tooling. 
+void xnn_x32_pack_lh_ukernel__neonsme2(size_t m, size_t k, size_t mr, + size_t kr, size_t sr, + size_t m_idx_start, + const float* XNN_RESTRICT lhs, + size_t lhs_stride, + void* XNN_RESTRICT lhs_packed) { +#if XNN_ENABLE_KLEIDIAI + kai_run_lhs_pack_f32p2vlx1_f32_sme(m, k, mr, kr, sr, m_idx_start, lhs, + lhs_stride, lhs_packed); +#else + assert("Not compiled with XNN_ENABLE_KLEIDIAI" && 0); +#endif // XNN_ENABLE_KLEIDIAI +} + +size_t xnn_x32_pack_lh_size__neonsme2(size_t m, size_t k, size_t mr, size_t kr, size_t sr) { +#if XNN_ENABLE_KLEIDIAI + return kai_get_lhs_packed_size_lhs_pack_f32p2vlx1_f32_sme(m, k, mr, kr, sr); +#else + assert("Not compiled with XNN_ENABLE_KLEIDIAI" && 0); +#endif // XNN_ENABLE_KLEIDIAI +} diff --git a/src/xnnpack/compute.h b/src/xnnpack/compute.h index 7a257df9034..aaa73bebcce 100644 --- a/src/xnnpack/compute.h +++ b/src/xnnpack/compute.h @@ -1580,17 +1580,35 @@ struct f32_qd8_convert_context { size_t batch_index); #endif -struct f32_qp8_convert_context { - size_t m; - size_t k; - size_t mr; - size_t kr; - size_t sr; - const float* XNN_RESTRICT lhs; - size_t lhs_stride; - int8_t* XNN_RESTRICT lhs_packed; - xnn_x8_packq_f32qp8_ukernel_fn packq_ukernel; -}; + struct x32_pack_lh_context { + size_t m; + size_t k; + size_t mr; + size_t kr; + size_t sr; + const float* XNN_RESTRICT lhs; + size_t lhs_stride; + float* XNN_RESTRICT lhs_packed; + xnn_x32_pack_lh_ukernel_fn pack_lh_ukernel; + }; + +#ifndef __cplusplus + XNN_PRIVATE void xnn_compute_x32_pack_lh( + const struct x32_pack_lh_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t m_idx_start, size_t tile); +#endif + + struct f32_qp8_convert_context { + size_t m; + size_t k; + size_t mr; + size_t kr; + size_t sr; + const float* XNN_RESTRICT lhs; + size_t lhs_stride; + int8_t* XNN_RESTRICT lhs_packed; + xnn_x8_packq_f32qp8_ukernel_fn packq_ukernel; + }; #ifndef __cplusplus XNN_PRIVATE void xnn_compute_f32_qp8_convert( diff --git a/src/xnnpack/config-types.h b/src/xnnpack/config-types.h index 3d9e5f9b309..61b328104d6 100644 --- a/src/xnnpack/config-types.h +++ b/src/xnnpack/config-types.h @@ -145,6 +145,11 @@ struct xnn_avgpool_config { uint16_t channel_tile; }; +struct xnn_pack_lh_config { + xnn_x32_pack_lh_ukernel_fn ukernel; + xnn_x32_pack_lh_size_fn size_fn; +}; + struct xnn_pavgpool_config { xnn_pavgpool_unipass_ukernel_fn unipass; xnn_pavgpool_multipass_ukernel_fn multipass; diff --git a/src/xnnpack/config.h b/src/xnnpack/config.h index 67c7ee16bc5..abd90052087 100644 --- a/src/xnnpack/config.h +++ b/src/xnnpack/config.h @@ -25,6 +25,8 @@ XNN_INTERNAL const struct xnn_transpose_config* xnn_init_transpose_config(); XNN_INTERNAL const struct xnn_cmul_config* xnn_init_f16_cmul_config(); XNN_INTERNAL const struct xnn_cmul_config* xnn_init_f32_cmul_config(); +XNN_INTERNAL const struct xnn_pack_lh_config* xnn_init_x32_pack_lh_config(); + XNN_INTERNAL const struct xnn_binary_elementwise_config* xnn_init_f16_vadd_config(); XNN_INTERNAL const struct xnn_binary_elementwise_config* xnn_init_f16_vdiv_config(); XNN_INTERNAL const struct xnn_binary_elementwise_config* xnn_init_f16_vmax_config(); @@ -249,6 +251,7 @@ XNN_INTERNAL const struct xnn_gemm_config* xnn_init_f32_gemm_config(); XNN_INTERNAL const struct xnn_gemm_config* xnn_init_f32_gemm_nr2_config(); XNN_INTERNAL const struct xnn_gemm_config* xnn_init_f32_qc8w_gemm_config(); XNN_INTERNAL const struct xnn_gemm_config* xnn_init_f32_qc4w_gemm_config(); +XNN_INTERNAL const struct xnn_gemm_config* xnn_init_pf32_gemm_config(); XNN_INTERNAL const struct xnn_gemm_config* 
xnn_init_qd8_f16_qb4w_gemm_config(); XNN_INTERNAL const struct xnn_gemm_config* xnn_init_qd8_f16_qc4w_gemm_config(); XNN_INTERNAL const struct xnn_gemm_config* xnn_init_qd8_f16_qc8w_gemm_config(); diff --git a/src/xnnpack/gemm.h b/src/xnnpack/gemm.h index 9d8ec3bf1fc..81dbf782393 100644 --- a/src/xnnpack/gemm.h +++ b/src/xnnpack/gemm.h @@ -157,6 +157,26 @@ DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_8x64__avx51 size_t cn_stride, \ const struct xnn_f32_relu_params params[XNN_RESTRICT XNN_MIN_ELEMENTS(1)]); +#define DECLARE_PF32_GEMM_MINMAX_UKERNEL_FUNCTION(fn_name) \ + \ + XNN_INTERNAL size_t fn_name ## _get_mr(); \ + XNN_INTERNAL size_t fn_name ## _get_nr(); \ + \ + XNN_INTERNAL void fn_name( \ + size_t mr, \ + size_t nc, \ + size_t kc, \ + const void* a, \ + size_t a_stride, \ + const float* w, \ + float* c, \ + size_t cm_stride, \ + size_t cn_stride, \ + const union xnn_f32_minmax_params params[XNN_RESTRICT XNN_MIN_ELEMENTS(1)]); + +DECLARE_PF32_GEMM_MINMAX_UKERNEL_FUNCTION( + xnn_pf32_gemm_minmax_ukernel_32x32__neonsme2) + #define DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(fn_name) \ XNN_INTERNAL void fn_name( \ size_t mr, \ diff --git a/src/xnnpack/internal.h b/src/xnnpack/internal.h index 369eaaf5b0b..848cd56ae04 100644 --- a/src/xnnpack/internal.h +++ b/src/xnnpack/internal.h @@ -61,6 +61,20 @@ enum xnn_status xnn_setup_convert_nc_f32_qp8(xnn_operator_t convert_op, // const float* input, // int8_t* output); +enum xnn_status xnn_create_pack_lh_x32(uint32_t flags, + xnn_operator_t* pack_lh_op_out); + +enum xnn_status xnn_reshape_pack_lh_x32(xnn_operator_t pack_lh_op, + size_t batch_size, size_t channels, + size_t* output_size_bytes, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_pack_lh_x32(xnn_operator_t pack_lh_op, + const void* input, void* output); + +enum xnn_status xnn_define_pack_lh(xnn_subgraph_t subgraph, uint32_t input_id, + uint32_t output_id, uint32_t flags); + enum xnn_status xnn_create_fully_connected_nc_qp8_f32_qb4w( size_t input_channels, // size_t output_channels, // @@ -88,6 +102,13 @@ enum xnn_status xnn_reshape_fully_connected_nc_qp8_f32_qb4w( size_t batch_size, // pthreadpool_t threadpool); +enum xnn_status xnn_create_fully_connected_nc_pf32( + size_t input_channels, size_t output_channels, size_t input_stride, + size_t output_stride, const float* kernel, const float* bias, + float output_min, float output_max, uint32_t flags, + xnn_code_cache_t code_cache, xnn_weights_cache_t weights_cache, + xnn_operator_t* fully_connected_op_out); + #ifdef __cplusplus } // extern "C" #endif diff --git a/src/xnnpack/microfnptr.h b/src/xnnpack/microfnptr.h index 1705904ee9c..1ac1c71cf8d 100644 --- a/src/xnnpack/microfnptr.h +++ b/src/xnnpack/microfnptr.h @@ -1549,6 +1549,16 @@ typedef void (*xnn_x32_packx_ukernel_fn)( size_t x_stride, uint32_t* y); +// PACKLH: PACK LH (input) tensor according to the parameters from the gemm +// config. +typedef void (*xnn_x32_pack_lh_ukernel_fn)( + size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, + const uint32_t* lhs, size_t lhs_stride, uint32_t* lhs_packed); + +// PACKLH Size: Size of packed buffer required. +typedef size_t (*xnn_x32_pack_lh_size_fn)(size_t m, size_t k, size_t mr, + size_t kr, size_t sr); + // FILL: FILL array with value typedef void (*xnn_fill_ukernel_fn)( @@ -2967,7 +2977,7 @@ struct xnn_hmp_qp8gemm_bl_ukernel { // Largest GEMM/IGEMM MR used in init.c is 16 (x86 AVX512AMX). // Largest GEMM/IGEMM MR is 8 in e2e benchmarks. 
-#define XNN_MAX_MR 16 +#define XNN_MAX_MR 32 struct gemm_fused_ukernels { union { diff --git a/src/xnnpack/node-type-defs.h b/src/xnnpack/node-type-defs.h index 08d7fe411b5..2ac4f508f80 100644 --- a/src/xnnpack/node-type-defs.h +++ b/src/xnnpack/node-type-defs.h @@ -50,6 +50,7 @@ XNN_ENUM_ITEM(xnn_node_type_maximum2, "Maximum2") XNN_ENUM_ITEM(xnn_node_type_minimum2, "Minimum2") XNN_ENUM_ITEM(xnn_node_type_multiply2, "Multiply2") XNN_ENUM_ITEM(xnn_node_type_negate, "Negate") +XNN_ENUM_ITEM(xnn_node_type_pack_lh, "Pack LH") XNN_ENUM_ITEM(xnn_node_type_prelu, "PReLU") XNN_ENUM_ITEM(xnn_node_type_reciprocal_square_root, "Reciprocal Square Root") XNN_ENUM_ITEM(xnn_node_type_rope, "RoPE") diff --git a/src/xnnpack/operator-type-defs.h b/src/xnnpack/operator-type-defs.h index e61372a2dcd..d170d25e5ea 100644 --- a/src/xnnpack/operator-type-defs.h +++ b/src/xnnpack/operator-type-defs.h @@ -82,6 +82,8 @@ XNN_ENUM_ITEM(xnn_operator_type_fully_connected_nc_f16, "Fully Connected (NC, F1 XNN_ENUM_ITEM(xnn_operator_type_fully_connected_nc_f32, "Fully Connected (NC, F32)") XNN_ENUM_ITEM(xnn_operator_type_fully_connected_nc_f32_qc4w, "Fully Connected (NC, F32, QC4W)") XNN_ENUM_ITEM(xnn_operator_type_fully_connected_nc_f32_qc8w, "Fully Connected (NC, F32, QC8W)") +XNN_ENUM_ITEM(xnn_operator_type_fully_connected_nc_pf32, + "Fully Connected (NC, PF32)") XNN_ENUM_ITEM(xnn_operator_type_fully_connected_nc_qd8_f16_qb4w, "Fully Connected (NC, QD8, F16, QB4W)") XNN_ENUM_ITEM(xnn_operator_type_fully_connected_nc_qd8_f16_qc4w, "Fully Connected (NC, QD8, F16, QC4W)") XNN_ENUM_ITEM(xnn_operator_type_fully_connected_nc_qd8_f16_qc8w, "Fully Connected (NC, QD8, F16, QC8W)") @@ -119,6 +121,7 @@ XNN_ENUM_ITEM(xnn_operator_type_minimum, "Minimum (ND)") XNN_ENUM_ITEM(xnn_operator_type_multiply, "Multiply (ND)") XNN_ENUM_ITEM(xnn_operator_type_negate_nc_f16, "Negate (NC, F16)") XNN_ENUM_ITEM(xnn_operator_type_negate_nc_f32, "Negate (NC, F32)") +XNN_ENUM_ITEM(xnn_operator_type_pack_lh_x32, "Pack LH (X32)") XNN_ENUM_ITEM(xnn_operator_type_prelu_nc_f16, "PReLU (NC, F16)") XNN_ENUM_ITEM(xnn_operator_type_prelu_nc_f32, "PReLU (NC, F32)") XNN_ENUM_ITEM(xnn_operator_type_reciprocal_square_root_nc_f16, "Reciprocal Square Root (NC, F16)") diff --git a/src/xnnpack/operator.h b/src/xnnpack/operator.h index 0fbb6c36670..73808204712 100644 --- a/src/xnnpack/operator.h +++ b/src/xnnpack/operator.h @@ -355,6 +355,7 @@ struct xnn_operator { enum xnn_attention_logits_cap_type cap_type; struct xnn_attention_logits_cap_tanh_params cap_params; } attention; // For attention operator. + const struct xnn_pack_lh_config* pack_lh_config; }; struct compute_parameters compute[XNN_MAX_COMPUTE_INVOCATIONS]; @@ -411,6 +412,7 @@ struct xnn_operator { struct unpooling_context unpooling; struct vmulcaddc_context vmulcaddc; struct rope_context rope; + struct x32_pack_lh_context x32_pack_lh; } context; struct xnn_code_cache* code_cache; diff --git a/src/xnnpack/pack-lh.h b/src/xnnpack/pack-lh.h new file mode 100644 index 00000000000..77b59478a3f --- /dev/null +++ b/src/xnnpack/pack-lh.h @@ -0,0 +1,31 @@ +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ +#pragma once + +#include +#include + +#include "xnnpack/common.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#define XNN_UKERNEL(arch_flags, ukernel, size_fn) \ + XNN_INTERNAL void ukernel(size_t m, size_t k, size_t mr, size_t kr, \ + size_t sr, size_t m_idx_start, const float* x, \ + size_t x_stride, void* y); \ + \ + XNN_INTERNAL size_t size_fn(size_t m, size_t k, size_t mr, size_t kr, \ + size_t sr); + +#include "x32-pack-lh/x32-pack-lh.h" + +#undef XNN_UKERNEL + +#ifdef __cplusplus +} // extern "C" +#endif diff --git a/src/xnnpack/pack.h b/src/xnnpack/pack.h index 6f8f4e9277d..afc556c77e3 100644 --- a/src/xnnpack/pack.h +++ b/src/xnnpack/pack.h @@ -484,6 +484,20 @@ XNN_INTERNAL size_t xnn_packed_stride_kai_qs4_weights_and_biases( size_t k_stride, // size_t extra_bytes); +size_t xnn_packed_stride_kai_f32_weights_and_biases( + const struct xnn_gemm_config* gemm_config, size_t k, size_t unused_k_stride, + size_t extra_bytes); + +void xnn_pack_kai_f32_weights_and_biases( + uint32_t flags, const struct xnn_gemm_config* gemm_config, + size_t input_channels, size_t output_channels, size_t groups, + size_t k_stride, const void* accumulator_init, const void* weights, + xnn_init_scale_params_fn init_extra_data0_fn, const void* extra_data0, + size_t extra_data0_element_size, + xnn_init_scale_params_fn init_extra_data1_fn, const void* extra_data1, + size_t extra_data1_element_size, void* packed_weights_ptr, + const void* params); + XNN_INTERNAL void xnn_pack_kai_qb4_weights_and_biases( uint32_t flags, // const struct xnn_gemm_config* gemm_config, // diff --git a/src/xnnpack/subgraph.h b/src/xnnpack/subgraph.h index d60e34bad28..d8c6eb8a188 100644 --- a/src/xnnpack/subgraph.h +++ b/src/xnnpack/subgraph.h @@ -486,6 +486,10 @@ struct xnn_runtime { enum xnn_status xnn_insert_clamp_node(xnn_subgraph_t subgraph, float output_min, float output_max, struct xnn_node *node); +enum xnn_status xnn_insert_pack_lh_node(xnn_subgraph_t subgraph, + const struct xnn_value* input, + uint32_t input_id, uint32_t* new_id); + struct xnn_value* xnn_subgraph_new_internal_value(xnn_subgraph_t subgraph); struct xnn_node* xnn_subgraph_new_node(xnn_subgraph_t subgraph); diff --git a/test/f32-gemm-minmax.yaml b/test/f32-gemm-minmax.yaml index 1139fe8380a..7e4a17de36e 100644 --- a/test/f32-gemm-minmax.yaml +++ b/test/f32-gemm-minmax.yaml @@ -973,3 +973,9 @@ init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_f32_gemm_goi_w k-block: 1 + +- name: xnn_f32_gemm_minmax_ukernel_32x32__neonsme2 + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_kai_f32_weights_and_biases + packed-stride: xnn_packed_stride_kai_f32_weights_and_biases + k-block: 2 diff --git a/tools/update-microkernels.py b/tools/update-microkernels.py index 14ebc1621d3..e6ecd944c78 100755 --- a/tools/update-microkernels.py +++ b/tools/update-microkernels.py @@ -183,6 +183,8 @@ def main(args): c_microkernels_per_isa['neondot_aarch64'] = list() c_microkernels_per_isa['neonfma_aarch64'] = list() c_microkernels_per_isa['neonfp16arith_aarch64'] = list() + c_microkernels_per_isa['neonsme'] = list() + c_microkernels_per_isa['neonsme2'] = list() asm_microkernels_per_arch = {arch: [] for arch in _ARCH_LIST} microkernel_name_to_filename = dict() for root, _, files in os.walk(src_dir, topdown=False):
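
The new pack-lh operator in src/operators/pack-lh.c follows the usual create/reshape/setup/run operator lifecycle. Below is a minimal sketch of driving it directly through the internal API declared in src/xnnpack/internal.h; the buffer handling and error-handling shortcuts are assumptions made for brevity, since in practice the operator is created on the caller's behalf by the subgraph rewrite in xnn_define_fully_connected.

#include <stdlib.h>

#include "xnnpack.h"
#include "xnnpack/internal.h"

// Sketch only: packs an m x k f32 LHS into the layout expected by the pf32
// GEMM. Return-value checks are omitted.
static void* pack_lhs_example(const float* lhs, size_t m, size_t k,
                              pthreadpool_t threadpool) {
  xnn_initialize(/*allocator=*/NULL);

  xnn_operator_t pack_lh_op = NULL;
  xnn_create_pack_lh_x32(/*flags=*/0, &pack_lh_op);

  // Reshape reports the packed size in bytes; it depends on the mr/kr/sr of
  // the pf32 GEMM config, not just on the tensor shape.
  size_t packed_size_bytes = 0;
  xnn_reshape_pack_lh_x32(pack_lh_op, /*batch_size=*/m, /*channels=*/k,
                          &packed_size_bytes, threadpool);

  void* lhs_packed = malloc(packed_size_bytes);
  xnn_setup_pack_lh_x32(pack_lh_op, lhs, lhs_packed);
  xnn_run_operator(pack_lh_op, threadpool);

  xnn_delete_operator(pack_lh_op);
  return lhs_packed;  // Ready to feed to a fully-connected (NC, PF32) operator.
}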
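At the subgraph level nothing changes for callers: xnn_define_fully_connected now checks whether a pf32 GEMM config is available and, if so, inserts a pack-lh node in front of the GEMM and coerces that node's output value to xnn_datatype_pfp32. The sketch below shows an ordinary fp32 fully connected graph that would transparently pick up this path on SME2 hardware with KleidiAI enabled; the shapes, external value ids, and omitted status checks are illustrative assumptions only.

#include <math.h>

#include "xnnpack.h"

// Illustrative shapes: batch=4, input_channels=32, output_channels=16.
static enum xnn_status build_fc_graph(const float* weights, const float* bias,
                                      xnn_subgraph_t* subgraph_out) {
  xnn_subgraph_t subgraph = NULL;
  xnn_create_subgraph(/*external_value_ids=*/2, /*flags=*/0, &subgraph);

  const size_t input_dims[2] = {4, 32};
  const size_t output_dims[2] = {4, 16};
  const size_t filter_dims[2] = {16, 32};  // [output_channels, input_channels]
  const size_t bias_dims[1] = {16};

  uint32_t input_id = XNN_INVALID_VALUE_ID;
  xnn_define_tensor_value(subgraph, xnn_datatype_fp32, 2, input_dims, NULL,
                          /*external_id=*/0, XNN_VALUE_FLAG_EXTERNAL_INPUT,
                          &input_id);

  uint32_t filter_id = XNN_INVALID_VALUE_ID;
  xnn_define_tensor_value(subgraph, xnn_datatype_fp32, 2, filter_dims, weights,
                          XNN_INVALID_VALUE_ID, /*flags=*/0, &filter_id);

  uint32_t bias_id = XNN_INVALID_VALUE_ID;
  xnn_define_tensor_value(subgraph, xnn_datatype_fp32, 1, bias_dims, bias,
                          XNN_INVALID_VALUE_ID, /*flags=*/0, &bias_id);

  uint32_t output_id = XNN_INVALID_VALUE_ID;
  xnn_define_tensor_value(subgraph, xnn_datatype_fp32, 2, output_dims, NULL,
                          /*external_id=*/1, XNN_VALUE_FLAG_EXTERNAL_OUTPUT,
                          &output_id);

  // When xnn_init_pf32_gemm_config() succeeds, this node is lowered to
  // pack-lh + fully-connected (NC, PF32); otherwise the regular f32 path runs.
  xnn_define_fully_connected(subgraph, /*output_min=*/-INFINITY,
                             /*output_max=*/INFINITY, input_id, filter_id,
                             bias_id, output_id, /*flags=*/0);

  *subgraph_out = subgraph;
  return xnn_status_success;
}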
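The pf32 GEMM wrapper exposes _get_mr()/_get_nr() because on SME the effective tile sizes depend on the runtime vector length, so the GEMM config is expected to query the kernel instead of hard-coding mr/nr (which is also why XNN_MAX_MR grows from 16 to 32). The pf32 gemm-config source itself is not part of this diff, so the field choices below are only a sketch of that idea, not the actual config code.

#include "xnnpack/config-types.h"
#include "xnnpack/gemm.h"

// Sketch: record whatever tile sizes the SME2 kernel reports at runtime.
static void init_pf32_gemm_config_sketch(struct xnn_gemm_config* config) {
  config->mr = xnn_pf32_gemm_minmax_ukernel_32x32__neonsme2_get_mr();
  config->nr = xnn_pf32_gemm_minmax_ukernel_32x32__neonsme2_get_nr();
  config->log2_kr = 0;  // kr = 1
  config->log2_sr = 0;  // sr = 1
}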
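Until a dedicated nxk packing routine is wired up, xnn_pack_kai_f32_weights_and_biases handles non-transposed weights by materializing a kxn copy with the transpose_weights helper. A tiny worked example of the layout it produces, using the same indexing as the helper in packing.cc:

#include <assert.h>
#include <stddef.h>

// A row-major 2x3 (height x width) matrix {1,2,3,4,5,6} becomes the 3x2
// matrix {1,4,2,5,3,6}.
static void transpose_weights_example(void) {
  const float in[6] = {1, 2, 3, 4, 5, 6};
  float out[6];
  const size_t height = 2, width = 3;
  for (size_t i = 0; i < height; ++i) {
    for (size_t j = 0; j < width; ++j) {
      out[j * height + i] = in[i * width + j];
    }
  }
  assert(out[1] == 4.0f && out[2] == 2.0f);
}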