Skip to content

Commit

Permalink
[Fast Packing] Add packing ukernels to gemm config
Browse files Browse the repository at this point in the history
  • Loading branch information
mcr229 committed Dec 9, 2024
1 parent 35b59d9 commit 3f03824
Show file tree
Hide file tree
Showing 8 changed files with 163 additions and 6 deletions.
5 changes: 4 additions & 1 deletion BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -783,7 +783,10 @@ xnnpack_cc_library(

xnnpack_cxx_library(
name = "packing",
srcs = ["src/reference/packing.cc"],
srcs = [
"src/reference/packing.cc",
"src/packw.c"
],
hdrs = ["src/xnnpack/pack.h"],
defines = xnnpack_configurable_defines(),
deps = [
Expand Down
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -793,7 +793,7 @@ ADD_LIBRARY(indirection OBJECT src/indirection.c)
ADD_LIBRARY(logging OBJECT ${LOGGING_SRCS})
ADD_LIBRARY(microparams-init OBJECT src/microparams-init.c)
ADD_LIBRARY(normalization OBJECT src/normalization.c)
ADD_LIBRARY(packing OBJECT src/reference/packing.cc)
ADD_LIBRARY(packing OBJECT src/reference/packing.cc src/packw.c)
TARGET_LINK_LIBRARIES(hardware-config PRIVATE xnnpack-base logging)
TARGET_LINK_LIBRARIES(indirection PRIVATE xnnpack-base)
TARGET_LINK_LIBRARIES(logging PRIVATE xnnpack-base)
Expand Down
4 changes: 2 additions & 2 deletions cmake/gen/scalar_microkernels.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,8 @@ SET(PROD_SCALAR_MICROKERNEL_SRCS
src/f32-vunary/gen/f32-vabs-scalar.c
src/f32-vunary/gen/f32-vneg-scalar.c
src/f32-vunary/gen/f32-vsqr-scalar.c
src/qb4-packw/gen/qb4-packw-x16c4-gemm-goi-scalar.c
src/qb4-packw/gen/qb4-packw-x16c8-gemm-goi-scalar.c
src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-1x4-minmax-scalar.c
src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-4x4-minmax-scalar.c
src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x4-minmax-scalar.c
Expand Down Expand Up @@ -541,8 +543,6 @@ SET(NON_PROD_SCALAR_MICROKERNEL_SRCS
src/f32-vsigmoid/gen/f32-vsigmoid-scalar-rr2-p5-div-u4.c
src/f32-vsqrt/gen/f32-vsqrt-scalar-sqrt-u2.c
src/f32-vsqrt/gen/f32-vsqrt-scalar-sqrt-u4.c
src/qb4-packw/gen/qb4-packw-x16c4-gemm-goi-scalar.c
src/qb4-packw/gen/qb4-packw-x16c8-gemm-goi-scalar.c
src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x2-minmax-scalar.c
src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x4-minmax-scalar.c
src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x8-minmax-scalar.c
Expand Down
4 changes: 2 additions & 2 deletions gen/scalar_microkernels.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ PROD_SCALAR_MICROKERNEL_SRCS = [
"src/f32-vunary/gen/f32-vabs-scalar.c",
"src/f32-vunary/gen/f32-vneg-scalar.c",
"src/f32-vunary/gen/f32-vsqr-scalar.c",
"src/qb4-packw/gen/qb4-packw-x16c4-gemm-goi-scalar.c",
"src/qb4-packw/gen/qb4-packw-x16c8-gemm-goi-scalar.c",
"src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-1x4-minmax-scalar.c",
"src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-4x4-minmax-scalar.c",
"src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x4-minmax-scalar.c",
Expand Down Expand Up @@ -538,8 +540,6 @@ NON_PROD_SCALAR_MICROKERNEL_SRCS = [
"src/f32-vsigmoid/gen/f32-vsigmoid-scalar-rr2-p5-div-u4.c",
"src/f32-vsqrt/gen/f32-vsqrt-scalar-sqrt-u2.c",
"src/f32-vsqrt/gen/f32-vsqrt-scalar-sqrt-u4.c",
"src/qb4-packw/gen/qb4-packw-x16c4-gemm-goi-scalar.c",
"src/qb4-packw/gen/qb4-packw-x16c8-gemm-goi-scalar.c",
"src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x2-minmax-scalar.c",
"src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x4-minmax-scalar.c",
"src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-1x8-minmax-scalar.c",
Expand Down
3 changes: 3 additions & 0 deletions src/configs/gemm-config.c
Original file line number Diff line number Diff line change
Expand Up @@ -1795,6 +1795,7 @@ static void init_qd8_f32_qb4w_gemm_config(void) {
qd8_f32_qb4w_gemm_config.nr = 16;
qd8_f32_qb4w_gemm_config.log2_kr = 2;
qd8_f32_qb4w_gemm_config.planes = 2;
qd8_f32_qb4w_gemm_config.pack_weights_and_biases = xnn_pack_qb4_x16c4_weights_and_biases;
#endif // XNN_ENABLE_ARM_DOTPROD
} else {
qd8_f32_qb4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x16__neon_mlal_lane);
Expand Down Expand Up @@ -1824,6 +1825,7 @@ static void init_qd8_f32_qb4w_gemm_config(void) {
qd8_f32_qb4w_gemm_config.nr = 16;
qd8_f32_qb4w_gemm_config.log2_kr = 3;
qd8_f32_qb4w_gemm_config.planes = 2;
qd8_f32_qb4w_gemm_config.pack_weights_and_biases = xnn_pack_qb4_x16c8_weights_and_biases;
#endif // XNN_ENABLE_ARM_I8MM
} else if (XNN_ENABLE_ARM_DOTPROD && hardware_config->use_arm_neon_dot) {
#if XNN_ENABLE_ARM_DOTPROD
Expand All @@ -1834,6 +1836,7 @@ static void init_qd8_f32_qb4w_gemm_config(void) {
qd8_f32_qb4w_gemm_config.nr = 16;
qd8_f32_qb4w_gemm_config.log2_kr = 2;
qd8_f32_qb4w_gemm_config.planes = 2;
qd8_f32_qb4w_gemm_config.pack_weights_and_biases = xnn_pack_qb4_x16c4_weights_and_biases;
#endif // XNN_ENABLE_ARM_DOTPROD
} else {
qd8_f32_qb4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x16__neon_mlal_lane);
Expand Down
105 changes: 105 additions & 0 deletions src/packw.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@

// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.


#include "xnnpack.h"
#include "xnnpack/common.h"
#include "xnnpack/config-types.h"
#include "xnnpack/log.h"
#include "xnnpack/math.h"
#include "xnnpack/microfnptr.h"
#include "xnnpack/microparams.h"
#include "xnnpack/microparams-init.h"
#include "xnnpack/packw.h"
#include "xnnpack/pack.h"
#include "xnnpack/unaligned.h"

// Packs qb4 GEMM weights, accumulator-init values and scales using the x16c8
// scalar packing microkernel.
//
// If XNN_FLAG_TRANSPOSE_WEIGHTS is set in `flags`, no packing microkernel is
// available for that layout, so every argument is forwarded unchanged to
// xnn_pack_qb4_weights_and_biases(). Otherwise
// xnn_qb4_packw_gemm_goi_ukernel_x16c8__scalar() is invoked with nr, kr and sr
// derived from `gemm_config`.
//
// On the microkernel path only `groups`, `output_channels`, `input_channels`,
// `block_size`, `weights`, `accumulator_init`, `extra_data1`,
// `packed_weights_ptr` and `params` are consumed; the remaining arguments are
// used solely by the fallback call.
//
// NOTE(review): the microkernel name suggests it expects nr == 16 and kr == 8;
// nothing here verifies that `gemm_config` matches -- confirm at the call
// sites that install this function.
void xnn_pack_qb4_x16c8_weights_and_biases(
    uint32_t flags, const struct xnn_gemm_config* gemm_config,
    size_t input_channels, size_t output_channels, size_t groups,
    size_t block_size, size_t k_stride, const void* accumulator_init, const void* weights,
    xnn_init_scale_params_fn init_extra_data0_fn, const void* extra_data0,
    size_t extra_data0_element_size,
    xnn_init_scale_params_fn init_extra_data1_fn, const void* extra_data1,
    size_t extra_data1_element_size, void* packed_weights_ptr,
    const void* params) {
  if (flags & XNN_FLAG_TRANSPOSE_WEIGHTS) {
    // No packing ukernel for gio: defer to the reference implementation.
    // The call and the return are separate statements because ISO C does not
    // allow `return <expression>;` in a function returning void.
    xnn_pack_qb4_weights_and_biases(
        flags, gemm_config, input_channels, output_channels, groups,
        block_size, k_stride, accumulator_init, weights, init_extra_data0_fn,
        extra_data0, extra_data0_element_size, init_extra_data1_fn, extra_data1,
        extra_data1_element_size, packed_weights_ptr, params);
    return;
  }

  const uint32_t nr = gemm_config->nr;
  const uint32_t kr = UINT32_C(1) << gemm_config->log2_kr;
  const uint32_t sr = UINT32_C(1) << gemm_config->log2_sr;

  // Extra bytes handed to the microkernel: nr 16-bit values for `bl` and nr
  // 32-bit values for `n`.
  // NOTE(review): presumably these are the per-block scales and the
  // per-channel data stored alongside the weights -- confirm against the
  // microkernel.
  const size_t extra_bytes_bl = nr * sizeof(uint16_t);
  const size_t extra_bytes_n = nr * sizeof(uint32_t);

  xnn_qb4_packw_gemm_goi_ukernel_x16c8__scalar(
      /*g=*/groups,
      /*nc=*/output_channels,
      /*kc=*/input_channels,
      /*nr=*/nr,
      /*kr=*/kr,
      /*sr=*/sr,
      /*bl=*/block_size,
      /*k=*/(const uint8_t*)weights,
      /*bias=*/(const int32_t*)accumulator_init,
      /*scale=*/(const xnn_bfloat16*)extra_data1,
      /*packed_weights=*/(int8_t*)packed_weights_ptr,
      /*extra_bytes_bl=*/extra_bytes_bl,
      /*extra_bytes_n=*/extra_bytes_n,
      /*params=*/(const struct xnn_qs8_qc4w_packing_params*)params);
}

// Packs qb4 GEMM weights, accumulator-init values and scales using the x16c4
// scalar packing microkernel.
//
// If XNN_FLAG_TRANSPOSE_WEIGHTS is set in `flags`, no packing microkernel is
// available for that layout, so every argument is forwarded unchanged to
// xnn_pack_qb4_weights_and_biases(). Otherwise
// xnn_qb4_packw_gemm_goi_ukernel_x16c4__scalar() is invoked with nr, kr and sr
// derived from `gemm_config`.
//
// On the microkernel path only `groups`, `output_channels`, `input_channels`,
// `block_size`, `weights`, `accumulator_init`, `extra_data1`,
// `packed_weights_ptr` and `params` are consumed; the remaining arguments are
// used solely by the fallback call.
//
// NOTE(review): the microkernel name suggests it expects nr == 16 and kr == 4;
// nothing here verifies that `gemm_config` matches -- confirm at the call
// sites that install this function.
void xnn_pack_qb4_x16c4_weights_and_biases(
    uint32_t flags, const struct xnn_gemm_config* gemm_config,
    size_t input_channels, size_t output_channels, size_t groups,
    size_t block_size, size_t k_stride, const void* accumulator_init, const void* weights,
    xnn_init_scale_params_fn init_extra_data0_fn, const void* extra_data0,
    size_t extra_data0_element_size,
    xnn_init_scale_params_fn init_extra_data1_fn, const void* extra_data1,
    size_t extra_data1_element_size, void* packed_weights_ptr,
    const void* params) {
  if (flags & XNN_FLAG_TRANSPOSE_WEIGHTS) {
    // No packing ukernel for gio: defer to the reference implementation.
    // The call and the return are separate statements because ISO C does not
    // allow `return <expression>;` in a function returning void.
    xnn_pack_qb4_weights_and_biases(
        flags, gemm_config, input_channels, output_channels, groups,
        block_size, k_stride, accumulator_init, weights, init_extra_data0_fn,
        extra_data0, extra_data0_element_size, init_extra_data1_fn, extra_data1,
        extra_data1_element_size, packed_weights_ptr, params);
    return;
  }

  const uint32_t nr = gemm_config->nr;
  const uint32_t kr = UINT32_C(1) << gemm_config->log2_kr;
  const uint32_t sr = UINT32_C(1) << gemm_config->log2_sr;

  // Extra bytes handed to the microkernel: nr 16-bit values for `bl` and nr
  // 32-bit values for `n`.
  // NOTE(review): presumably these are the per-block scales and the
  // per-channel data stored alongside the weights -- confirm against the
  // microkernel.
  const size_t extra_bytes_bl = nr * sizeof(uint16_t);
  const size_t extra_bytes_n = nr * sizeof(uint32_t);

  xnn_qb4_packw_gemm_goi_ukernel_x16c4__scalar(
      /*g=*/groups,
      /*nc=*/output_channels,
      /*kc=*/input_channels,
      /*nr=*/nr,
      /*kr=*/kr,
      /*sr=*/sr,
      /*bl=*/block_size,
      /*k=*/(const uint8_t*)weights,
      /*bias=*/(const int32_t*)accumulator_init,
      /*scale=*/(const xnn_bfloat16*)extra_data1,
      /*packed_weights=*/(int8_t*)packed_weights_ptr,
      /*extra_bytes_bl=*/extra_bytes_bl,
      /*extra_bytes_n=*/extra_bytes_n,
      /*params=*/(const struct xnn_qs8_qc4w_packing_params*)params);
}
38 changes: 38 additions & 0 deletions src/xnnpack/pack.h
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,44 @@ XNN_INTERNAL void xnn_pack_qb4_weights_and_biases(
void* packed_weights_ptr, //
const void* params);

// Packs qb4 weights and biases via the x16c4 packing microkernel
// (implemented in src/packw.c). When XNN_FLAG_TRANSPOSE_WEIGHTS is set in
// `flags`, the implementation forwards all arguments unchanged to
// xnn_pack_qb4_weights_and_biases() instead, because no packing microkernel
// exists for that layout.
XNN_INTERNAL void xnn_pack_qb4_x16c4_weights_and_biases(
uint32_t flags, //
const struct xnn_gemm_config* gemm_config, //
size_t input_channels, //
size_t output_channels, //
size_t groups, //
size_t block_size, //
size_t k_stride, //
const void* accumulator_init, //
const void* weights, //
xnn_init_scale_params_fn init_extra_data0_fn, //
const void* extra_data0, //
size_t extra_data0_element_size, //
xnn_init_scale_params_fn init_extra_data1_fn, //
const void* extra_data1, //
size_t extra_data1_element_size, //
void* packed_weights_ptr, //
const void* params);

// Packs qb4 weights and biases via the x16c8 packing microkernel
// (implemented in src/packw.c). When XNN_FLAG_TRANSPOSE_WEIGHTS is set in
// `flags`, the implementation forwards all arguments unchanged to
// xnn_pack_qb4_weights_and_biases() instead, because no packing microkernel
// exists for that layout.
XNN_INTERNAL void xnn_pack_qb4_x16c8_weights_and_biases(
uint32_t flags, //
const struct xnn_gemm_config* gemm_config, //
size_t input_channels, //
size_t output_channels, //
size_t groups, //
size_t block_size, //
size_t k_stride, //
const void* accumulator_init, //
const void* weights, //
xnn_init_scale_params_fn init_extra_data0_fn, //
const void* extra_data0, //
size_t extra_data0_element_size, //
xnn_init_scale_params_fn init_extra_data1_fn, //
const void* extra_data1, //
size_t extra_data1_element_size, //
void* packed_weights_ptr, //
const void* params);

XNN_INTERNAL size_t xnn_packed_stride_qb4_weights_and_biases(
const struct xnn_gemm_config* gemm_config, //
size_t k, //
Expand Down
8 changes: 8 additions & 0 deletions tools/update-microkernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,14 @@ def main(args):
content = config_file.read()
microkernels = re.findall(_MICROKERNEL_NAME_REGEX, content)
prod_microkernels.update(microkernels)
# Also check prod packing ukernels in packw.c
with open(
os.path.join(src_dir, 'packw.c'), 'r', encoding='utf-8'
) as packw_file:
content = packw_file.read()
microkernels = re.findall(_MICROKERNEL_NAME_REGEX, content)
prod_microkernels.update(microkernels)

prod_microkernels = set(
map(microkernel_name_to_filename.get, prod_microkernels)
)
Expand Down

0 comments on commit 3f03824

Please sign in to comment.