From 3f0382475af170da339421c93ccb70d990ba1637 Mon Sep 17 00:00:00 2001 From: Max Ren Date: Tue, 26 Nov 2024 16:34:01 -0500 Subject: [PATCH] [Fast Packing] Add packing ukernels to gemm config --- BUILD.bazel | 5 +- CMakeLists.txt | 2 +- cmake/gen/scalar_microkernels.cmake | 4 +- gen/scalar_microkernels.bzl | 4 +- src/configs/gemm-config.c | 3 + src/packw.c | 105 ++++++++++++++++++++++++++++ src/xnnpack/pack.h | 38 ++++++++++ tools/update-microkernels.py | 8 +++ 8 files changed, 163 insertions(+), 6 deletions(-) create mode 100644 src/packw.c diff --git a/BUILD.bazel b/BUILD.bazel index 0a7cdef9b2f..245a722a444 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -783,7 +783,10 @@ xnnpack_cc_library( xnnpack_cxx_library( name = "packing", - srcs = ["src/reference/packing.cc"], + srcs = [ + "src/reference/packing.cc", + "src/packw.c" + ], hdrs = ["src/xnnpack/pack.h"], defines = xnnpack_configurable_defines(), deps = [ diff --git a/CMakeLists.txt b/CMakeLists.txt index 1b1688ed70b..18a8d92ebaa 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -793,7 +793,7 @@ ADD_LIBRARY(indirection OBJECT src/indirection.c) ADD_LIBRARY(logging OBJECT ${LOGGING_SRCS}) ADD_LIBRARY(microparams-init OBJECT src/microparams-init.c) ADD_LIBRARY(normalization OBJECT src/normalization.c) -ADD_LIBRARY(packing OBJECT src/reference/packing.cc) +ADD_LIBRARY(packing OBJECT src/reference/packing.cc src/packw.c) TARGET_LINK_LIBRARIES(hardware-config PRIVATE xnnpack-base logging) TARGET_LINK_LIBRARIES(indirection PRIVATE xnnpack-base) TARGET_LINK_LIBRARIES(logging PRIVATE xnnpack-base) diff --git a/cmake/gen/scalar_microkernels.cmake b/cmake/gen/scalar_microkernels.cmake index be694e9477e..efd75af706e 100644 --- a/cmake/gen/scalar_microkernels.cmake +++ b/cmake/gen/scalar_microkernels.cmake @@ -135,6 +135,8 @@ SET(PROD_SCALAR_MICROKERNEL_SRCS src/f32-vunary/gen/f32-vabs-scalar.c src/f32-vunary/gen/f32-vneg-scalar.c src/f32-vunary/gen/f32-vsqr-scalar.c + src/qb4-packw/gen/qb4-packw-x16c4-gemm-goi-scalar.c 
+ src/qb4-packw/gen/qb4-packw-x16c8-gemm-goi-scalar.c src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-1x4-minmax-scalar.c src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-4x4-minmax-scalar.c src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x4-minmax-scalar.c @@ -541,8 +543,6 @@ SET(NON_PROD_SCALAR_MICROKERNEL_SRCS src/f32-vsigmoid/gen/f32-vsigmoid-scalar-rr2-p5-div-u4.c src/f32-vsqrt/gen/f32-vsqrt-scalar-sqrt-u2.c src/f32-vsqrt/gen/f32-vsqrt-scalar-sqrt-u4.c - src/qb4-packw/gen/qb4-packw-x16c4-gemm-goi-scalar.c - src/qb4-packw/gen/qb4-packw-x16c8-gemm-goi-scalar.c src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x2-minmax-scalar.c src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x4-minmax-scalar.c src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x8-minmax-scalar.c diff --git a/gen/scalar_microkernels.bzl b/gen/scalar_microkernels.bzl index fe920fe5c76..f75030ff538 100644 --- a/gen/scalar_microkernels.bzl +++ b/gen/scalar_microkernels.bzl @@ -131,6 +131,8 @@ PROD_SCALAR_MICROKERNEL_SRCS = [ "src/f32-vunary/gen/f32-vabs-scalar.c", "src/f32-vunary/gen/f32-vneg-scalar.c", "src/f32-vunary/gen/f32-vsqr-scalar.c", + "src/qb4-packw/gen/qb4-packw-x16c4-gemm-goi-scalar.c", + "src/qb4-packw/gen/qb4-packw-x16c8-gemm-goi-scalar.c", "src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-1x4-minmax-scalar.c", "src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-4x4-minmax-scalar.c", "src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x4-minmax-scalar.c", @@ -538,8 +540,6 @@ NON_PROD_SCALAR_MICROKERNEL_SRCS = [ "src/f32-vsigmoid/gen/f32-vsigmoid-scalar-rr2-p5-div-u4.c", "src/f32-vsqrt/gen/f32-vsqrt-scalar-sqrt-u2.c", "src/f32-vsqrt/gen/f32-vsqrt-scalar-sqrt-u4.c", - "src/qb4-packw/gen/qb4-packw-x16c4-gemm-goi-scalar.c", - "src/qb4-packw/gen/qb4-packw-x16c8-gemm-goi-scalar.c", "src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x2-minmax-scalar.c", "src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x4-minmax-scalar.c", "src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x8-minmax-scalar.c", diff --git a/src/configs/gemm-config.c 
b/src/configs/gemm-config.c index d5236d1f5d4..357ecbb1a8d 100644 --- a/src/configs/gemm-config.c +++ b/src/configs/gemm-config.c @@ -1795,6 +1795,7 @@ static void init_qd8_f32_qb4w_gemm_config(void) { qd8_f32_qb4w_gemm_config.nr = 16; qd8_f32_qb4w_gemm_config.log2_kr = 2; qd8_f32_qb4w_gemm_config.planes = 2; + qd8_f32_qb4w_gemm_config.pack_weights_and_biases = xnn_pack_qb4_x16c4_weights_and_biases; #endif // XNN_ENABLE_ARM_DOTPROD } else { qd8_f32_qb4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x16__neon_mlal_lane); @@ -1824,6 +1825,7 @@ static void init_qd8_f32_qb4w_gemm_config(void) { qd8_f32_qb4w_gemm_config.nr = 16; qd8_f32_qb4w_gemm_config.log2_kr = 3; qd8_f32_qb4w_gemm_config.planes = 2; + qd8_f32_qb4w_gemm_config.pack_weights_and_biases = xnn_pack_qb4_x16c8_weights_and_biases; #endif // XNN_ENABLE_ARM_I8MM } else if (XNN_ENABLE_ARM_DOTPROD && hardware_config->use_arm_neon_dot) { #if XNN_ENABLE_ARM_DOTPROD @@ -1834,6 +1836,7 @@ static void init_qd8_f32_qb4w_gemm_config(void) { qd8_f32_qb4w_gemm_config.nr = 16; qd8_f32_qb4w_gemm_config.log2_kr = 2; qd8_f32_qb4w_gemm_config.planes = 2; + qd8_f32_qb4w_gemm_config.pack_weights_and_biases = xnn_pack_qb4_x16c4_weights_and_biases; #endif // XNN_ENABLE_ARM_DOTPROD } else { qd8_f32_qb4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x16__neon_mlal_lane); diff --git a/src/packw.c b/src/packw.c new file mode 100644 index 00000000000..9d53896b902 --- /dev/null +++ b/src/packw.c @@ -0,0 +1,105 @@ + +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// Copyright 2019 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+
+
+#include "xnnpack.h"
+#include "xnnpack/common.h"
+#include "xnnpack/config-types.h"
+#include "xnnpack/log.h"
+#include "xnnpack/math.h"
+#include "xnnpack/microfnptr.h"
+#include "xnnpack/microparams.h"
+#include "xnnpack/microparams-init.h"
+#include "xnnpack/packw.h"
+#include "xnnpack/pack.h"
+#include "xnnpack/unaligned.h"
+
+// Packs qb4 weights with the x16c8 scalar GOI packing ukernel (extra_data1 is
+// passed as bf16 scales); XNN_FLAG_TRANSPOSE_WEIGHTS takes the generic packer.
+void xnn_pack_qb4_x16c8_weights_and_biases(
+    uint32_t flags, const struct xnn_gemm_config* gemm_config,
+    size_t input_channels, size_t output_channels, size_t groups,
+    size_t block_size, size_t k_stride, const void* accumulator_init,
+    const void* weights, xnn_init_scale_params_fn init_extra_data0_fn,
+    const void* extra_data0, size_t extra_data0_element_size,
+    xnn_init_scale_params_fn init_extra_data1_fn, const void* extra_data1,
+    size_t extra_data1_element_size, void* packed_weights_ptr,
+    const void* params) {
+  if (flags & XNN_FLAG_TRANSPOSE_WEIGHTS) {
+    // No packing ukernel for GIO weights; delegate to the generic packer.
+    xnn_pack_qb4_weights_and_biases(
+        flags, gemm_config, input_channels, output_channels, groups,
+        block_size, k_stride, accumulator_init, weights, init_extra_data0_fn,
+        extra_data0, extra_data0_element_size, init_extra_data1_fn, extra_data1,
+        extra_data1_element_size, packed_weights_ptr, params);
+    return;
+  }
+  const uint32_t nr = gemm_config->nr;
+  const uint32_t kr = UINT32_C(1) << gemm_config->log2_kr;
+  const uint32_t sr = UINT32_C(1) << gemm_config->log2_sr;
+  const size_t extra_bytes_bl = sizeof(uint16_t);
+  const size_t extra_bytes_n = sizeof(uint32_t);
+  xnn_qb4_packw_gemm_goi_ukernel_x16c8__scalar(
+      /*g=*/groups,
+      /*nc=*/output_channels,
+      /*kc=*/input_channels,
+      /*nr=*/nr,
+      /*kr=*/kr,
+      /*sr=*/sr,
+      /*bl=*/block_size,
+      /*k=*/(const uint8_t*)weights,
+      /*bias=*/(const int32_t*)accumulator_init,
+      /*scale=*/(const xnn_bfloat16*)extra_data1,
+      /*packed_weights=*/(int8_t*)packed_weights_ptr,
+      /*extra_bytes_bl=*/nr * extra_bytes_bl,
+      /*extra_bytes_n=*/nr * extra_bytes_n,
+      /*params=*/(const struct xnn_qs8_qc4w_packing_params*)params);
+}
+
+// Packs qb4 weights with the x16c4 scalar GOI packing ukernel (extra_data1 is
+// passed as bf16 scales); XNN_FLAG_TRANSPOSE_WEIGHTS takes the generic packer.
+void xnn_pack_qb4_x16c4_weights_and_biases(
+    uint32_t flags, const struct xnn_gemm_config* gemm_config,
+    size_t input_channels, size_t output_channels, size_t groups,
+    size_t block_size, size_t k_stride, const void* accumulator_init,
+    const void* weights, xnn_init_scale_params_fn init_extra_data0_fn,
+    const void* extra_data0, size_t extra_data0_element_size,
+    xnn_init_scale_params_fn init_extra_data1_fn, const void* extra_data1,
+    size_t extra_data1_element_size, void* packed_weights_ptr,
+    const void* params) {
+  if (flags & XNN_FLAG_TRANSPOSE_WEIGHTS) {
+    // No packing ukernel for GIO weights; delegate to the generic packer.
+    xnn_pack_qb4_weights_and_biases(
+        flags, gemm_config, input_channels, output_channels, groups,
+        block_size, k_stride, accumulator_init, weights, init_extra_data0_fn,
+        extra_data0, extra_data0_element_size, init_extra_data1_fn, extra_data1,
+        extra_data1_element_size, packed_weights_ptr, params);
+    return;
+  }
+  const uint32_t nr = gemm_config->nr;
+  const uint32_t kr = UINT32_C(1) << gemm_config->log2_kr;
+  const uint32_t sr = UINT32_C(1) << gemm_config->log2_sr;
+  const size_t extra_bytes_bl = sizeof(uint16_t);
+  const size_t extra_bytes_n = sizeof(uint32_t);
+  xnn_qb4_packw_gemm_goi_ukernel_x16c4__scalar(
+      /*g=*/groups,
+      /*nc=*/output_channels,
+      /*kc=*/input_channels,
+      /*nr=*/nr,
+      /*kr=*/kr,
+      /*sr=*/sr,
+      /*bl=*/block_size,
+      /*k=*/(const uint8_t*)weights,
+      /*bias=*/(const int32_t*)accumulator_init,
+      /*scale=*/(const xnn_bfloat16*)extra_data1,
+      /*packed_weights=*/(int8_t*)packed_weights_ptr,
+      /*extra_bytes_bl=*/nr * extra_bytes_bl,
+      /*extra_bytes_n=*/nr * extra_bytes_n,
+      /*params=*/(const struct xnn_qs8_qc4w_packing_params*)params);
+}
diff --git a/src/xnnpack/pack.h b/src/xnnpack/pack.h
index c58617bec60..7a6f6ba9b15 100644
--- a/src/xnnpack/pack.h
+++ b/src/xnnpack/pack.h
@@ -453,6 +453,44 @@ XNN_INTERNAL void xnn_pack_qb4_weights_and_biases(
void* packed_weights_ptr, // const void* params); +XNN_INTERNAL void xnn_pack_qb4_x16c4_weights_and_biases( + uint32_t flags, // + const struct xnn_gemm_config* gemm_config, // + size_t input_channels, // + size_t output_channels, // + size_t groups, // + size_t block_size, // + size_t k_stride, // + const void* accumulator_init, // + const void* weights, // + xnn_init_scale_params_fn init_extra_data0_fn, // + const void* extra_data0, // + size_t extra_data0_element_size, // + xnn_init_scale_params_fn init_extra_data1_fn, // + const void* extra_data1, // + size_t extra_data1_element_size, // + void* packed_weights_ptr, // + const void* params); + +XNN_INTERNAL void xnn_pack_qb4_x16c8_weights_and_biases( + uint32_t flags, // + const struct xnn_gemm_config* gemm_config, // + size_t input_channels, // + size_t output_channels, // + size_t groups, // + size_t block_size, // + size_t k_stride, // + const void* accumulator_init, // + const void* weights, // + xnn_init_scale_params_fn init_extra_data0_fn, // + const void* extra_data0, // + size_t extra_data0_element_size, // + xnn_init_scale_params_fn init_extra_data1_fn, // + const void* extra_data1, // + size_t extra_data1_element_size, // + void* packed_weights_ptr, // + const void* params); + XNN_INTERNAL size_t xnn_packed_stride_qb4_weights_and_biases( const struct xnn_gemm_config* gemm_config, // size_t k, // diff --git a/tools/update-microkernels.py b/tools/update-microkernels.py index e6ecd944c78..f153f4d25c3 100755 --- a/tools/update-microkernels.py +++ b/tools/update-microkernels.py @@ -297,6 +297,14 @@ def main(args): content = config_file.read() microkernels = re.findall(_MICROKERNEL_NAME_REGEX, content) prod_microkernels.update(microkernels) + # Also check prod packing ukernels in packw.c + with open( + os.path.join(src_dir, 'packw.c'), 'r', encoding='utf-8' + ) as packw_file: + content = packw_file.read() + microkernels = re.findall(_MICROKERNEL_NAME_REGEX, content) + prod_microkernels.update(microkernels) + 
prod_microkernels = set( map(microkernel_name_to_filename.get, prod_microkernels) )