From 23dcc66cfe67741b53f85de6496bcb7b6ddce73e Mon Sep 17 00:00:00 2001 From: Max Ren Date: Sun, 1 Dec 2024 23:12:04 -0500 Subject: [PATCH 1/3] [Packing Refactor] Move all Blockwise Packing to pack_weights_and_biases --- src/configs/gemm-config.c | 9 +- src/operators/batch-matrix-multiply-nc.c | 10 +- src/operators/convolution-nhwc.c | 1 + src/operators/fully-connected-nc.c | 96 ++------------ src/reference/packing.cc | 152 ++++++++++++++++++++--- src/xnnpack/config-types.h | 2 - src/xnnpack/microfnptr.h | 2 + src/xnnpack/pack.h | 42 ++++++- 8 files changed, 199 insertions(+), 115 deletions(-) diff --git a/src/configs/gemm-config.c b/src/configs/gemm-config.c index 78bd447a9b2..05be86767dd 100644 --- a/src/configs/gemm-config.c +++ b/src/configs/gemm-config.c @@ -1539,7 +1539,8 @@ static void init_qd8_f16_qc4w_gemm_config(void) { } static void init_qd8_f16_qb4w_gemm_config(void) { - qd8_f16_qb4w_gemm_config.pack_gemm_goi_bl = (xnn_packw_gemm_goi_bl_ukernel_fn) xnn_pack_qs8_qb4w_gemm_goi_w; + qd8_f16_qb4w_gemm_config.packed_stride_weights_and_biases = xnn_packed_stride_qb4_weights_and_biases; + qd8_f16_qb4w_gemm_config.pack_weights_and_biases = xnn_pack_qb4_weights_and_biases; #if XNN_ARCH_ARM && XNN_ENABLE_ARM_FP16_VECTOR && XNN_ENABLE_ARM_FP16_SCALAR const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); @@ -1845,7 +1846,8 @@ static void init_qp8_f32_qb4w_gemm_config(void) { } static void init_qdu8_f32_qb4w_gemm_config(void) { - qdu8_f32_qb4w_gemm_config.pack_gemm_goi_bl = (xnn_packw_gemm_goi_bl_ukernel_fn) xnn_pack_qs8_qb4w_gemm_goi_w; + qdu8_f32_qb4w_gemm_config.packed_stride_weights_and_biases = xnn_packed_stride_qb4_weights_and_biases; + qdu8_f32_qb4w_gemm_config.pack_weights_and_biases = xnn_pack_qb4_weights_and_biases; #if XNN_ARCH_X86 || XNN_ARCH_X86_64 const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); assert(hardware_config != NULL); @@ -1884,7 +1886,8 @@ static void init_qdu8_f32_qb4w_gemm_config(void) { } static void init_qd8_f32_qb4w_gemm_config(void) { - qd8_f32_qb4w_gemm_config.pack_gemm_goi_bl = (xnn_packw_gemm_goi_bl_ukernel_fn) xnn_pack_qs8_qb4w_gemm_goi_w; + qd8_f32_qb4w_gemm_config.packed_stride_weights_and_biases = xnn_packed_stride_qb4_weights_and_biases; + qd8_f32_qb4w_gemm_config.pack_weights_and_biases = xnn_pack_qb4_weights_and_biases; #if XNN_ARCH_ARM const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); diff --git a/src/operators/batch-matrix-multiply-nc.c b/src/operators/batch-matrix-multiply-nc.c index c19825035a6..f85338dcbe2 100644 --- a/src/operators/batch-matrix-multiply-nc.c +++ b/src/operators/batch-matrix-multiply-nc.c @@ -182,7 +182,9 @@ enum xnn_status xnn_create_batch_matrix_multiply_nc_f32_const_weights( // Pack the weights. if (gemm_config->pack_weights_and_biases) { gemm_config->pack_weights_and_biases(flags, gemm_config, k, n, - /*groups=*/batch_size_b, k_stride, + /*groups=*/batch_size_b, + /*unused_block_size=*/0, + /*k_stride=*/k_stride, /*accumulator_init=*/NULL, /*weights=*/data_b, /*int_extra_data0_fn=*/NULL, @@ -311,7 +313,7 @@ enum xnn_status create_batch_matrix_multiply_nc_qx8_f32_qc8w( const size_t weights_stride = gemm_config->packed_stride_weights_and_biases ?
gemm_config->packed_stride_weights_and_biases( - gemm_config, k, k_stride, extra_bytes) + gemm_config, k, /*unused_block_size=*/0, k_stride, extra_bytes) : (k_stride << XNN_LOG2_SIZEOF_INT8_T) + extra_bytes + sizeof(int32_t); batch_matrix_multiply_op->weights_stride = weights_stride; @@ -347,7 +349,9 @@ enum xnn_status create_batch_matrix_multiply_nc_qx8_f32_qc8w( batch_matrix_multiply_op->flags ^ XNN_FLAG_TRANSPOSE_WEIGHTS, gemm_config, /*input_channels=*/k, /*output_channels=*/n, - /*groups=*/batch_size_b, k_stride, + /*groups=*/batch_size_b, + /*unused_block_size=*/0, + /*k_stride=*/k_stride, /*accumulator_init=*/NULL, /*weights=*/data_b, /*int_extra_data0_fn=*/ diff --git a/src/operators/convolution-nhwc.c b/src/operators/convolution-nhwc.c index bfc0be24eb7..3f3a22b3720 100644 --- a/src/operators/convolution-nhwc.c +++ b/src/operators/convolution-nhwc.c @@ -372,6 +372,7 @@ static enum xnn_status create_gemm_or_igemm( gemm_config->pack_weights_and_biases( flags, gemm_config, group_input_channels, group_output_channels, groups, + /*unused_block_size=*/0, k_stride, /*accumulator_init=*/bias, /*weights=*/kernel, diff --git a/src/operators/fully-connected-nc.c b/src/operators/fully-connected-nc.c index 528d58a31f9..83e8bc9e743 100644 --- a/src/operators/fully-connected-nc.c +++ b/src/operators/fully-connected-nc.c @@ -44,7 +44,6 @@ static enum xnn_status create_fully_connected_nc( const void* bias, uint32_t flags, size_t block_size, - size_t extra_bl_bytes, const uint16_t* blockwise_kernel_scale_params, uint32_t log2_input_element_size, uint32_t log2_filter_element_size, @@ -52,7 +51,6 @@ static enum xnn_status create_fully_connected_nc( uint32_t bias_element_size, xnn_packw_gemm_gio_ukernel_fn pack_gemm_gio_w, xnn_packw_gemm_goi_ukernel_fn pack_gemm_goi_w, - xnn_packw_gemm_goi_bl_ukernel_fn pack_gemm_goi_bl_w, const void* packing_params, int packed_weights_padding_byte, size_t extra_weights_bytes, @@ -155,7 +153,7 @@ static enum xnn_status create_fully_connected_nc( const size_t weights_stride = gemm_config->packed_stride_weights_and_biases ? gemm_config->packed_stride_weights_and_biases( - gemm_config, input_channels, block_wise ? block_size : k_stride, extra_weights_bytes) + gemm_config, input_channels, block_size, k_stride, extra_weights_bytes) : (k_stride << log2_filter_element_size) + bias_element_size + extra_weights_bytes + block_scale_bytes; const size_t packed_weights_size = n_stride * weights_stride; @@ -192,7 +190,8 @@ static enum xnn_status create_fully_connected_nc( gemm_config->pack_weights_and_biases( flags, gemm_config, input_channels, output_channels, /*groups=*/1, - block_wise ? block_size : k_stride, + /*block_size=*/block_size, + /*k_stride=*/k_stride, /*accumulator_init=*/bias, /*weights=*/kernel, /*int_extra_data0_fn=*/(xnn_init_scale_params_fn)init_scale_params, /*extra_data0=*/NULL, /*extra_data0_size=*/init_scale_params != NULL ? sizeof(float) : 0, /*init_extra_data1_fn=*/ (xnn_init_scale_params_fn)init_kernel_scale_params, /*extra_data1=*/kernel_scale_params, /*extra_data1_size=*/init_kernel_scale_params != NULL ?
sizeof(float) : 0, /*packed_weights_ptr=*/weights_ptr, packing_params); - - if (block_wise && bias != NULL) { - void* weights_start = (void*) ((uintptr_t) weights_ptr + - gemm_config->nr * (sizeof(float) + (block_size * sizeof(int8_t) / 2))); - weights_start = (void*) ((uintptr_t) weights_ptr + gemm_config->nr * (weights_stride - sizeof(float))) ; - xnn_init_qs8_qc8w_scale_fp32_params( - output_channels, gemm_config->nr, gemm_config->nr, - gemm_config->nr * weights_stride, gemm_config->nr * weights_stride, 0, - bias, weights_start); - } } else { if (flags & XNN_FLAG_TRANSPOSE_WEIGHTS) { pack_gemm_gio_w( @@ -225,24 +214,13 @@ static enum xnn_status create_fully_connected_nc( gemm_config->nr * extra_weights_bytes, packing_params); } else { - if (block_wise) { - pack_gemm_goi_bl_w( - /*groups=*/1, output_channels, input_channels, - nr, kr, sr, block_size, - kernel, /*bias=*/NULL, /*scale=*/blockwise_kernel_scale_params, - weights_ptr, - gemm_config->nr * extra_bl_bytes, - gemm_config->nr * extra_weights_bytes, - packing_params); - } else { - pack_gemm_goi_w( - /*groups=*/1, output_channels, input_channels, - nr, kr, sr, - kernel, bias, /*scale=*/NULL, - weights_ptr, - gemm_config->nr * extra_weights_bytes, - packing_params); - } + pack_gemm_goi_w( + /*groups=*/1, output_channels, input_channels, + nr, kr, sr, + kernel, bias, /*scale=*/NULL, + weights_ptr, + gemm_config->nr * extra_weights_bytes, + packing_params); } if (kernel_scale_params != NULL) { assert(init_kernel_scale_params != NULL); @@ -267,32 +245,6 @@ static enum xnn_status create_fully_connected_nc( gemm_config->nr * weights_stride, gemm_config->nr * weights_stride, 0, scale_params, weights); } - - if (block_wise) { - // Fill in kernel scale. - void* weights_start = (void*) ((uintptr_t) weights_ptr + - gemm_config->nr * (sizeof(float) + (block_size * sizeof(int8_t) / 2))); - - const size_t block_stride = /*weights*/block_size / 2 + sizeof(uint16_t); - - xnn_init_blockwise_scale_bf16_params( - output_channels, gemm_config->nr, gemm_config->nr, - gemm_config->nr * weights_stride, - gemm_config->nr * weights_stride, - /*num_blocks=*/num_blocks, - /*block_stride=*/gemm_config->nr * block_stride, - 0, - (const xnn_bfloat16*)blockwise_kernel_scale_params, weights_start); - - // Fill in bias. 
- if (bias != NULL) { - weights_start = (void*) ((uintptr_t) weights_ptr + gemm_config->nr * (weights_stride - sizeof(float))) ; - xnn_init_qs8_qc8w_scale_fp32_params( - output_channels, gemm_config->nr, gemm_config->nr, - gemm_config->nr * weights_stride, gemm_config->nr * weights_stride, 0, - bias, weights_start); - } - } } if (use_weights_cache(fully_connected_op)) { @@ -397,7 +349,6 @@ enum xnn_status create_fully_connected_nc_f16( input_stride, output_stride, kernel, bias, flags, /*block_size=*/0, - /*extra_bl_bytes=*/0, /*blockwise_kernel_scale_params=*/NULL, /*log2_input_element_size=*/XNN_LOG2_SIZEOF_HALF, /*log2_filter_element_size=*/XNN_LOG2_SIZEOF_HALF, @@ -405,7 +356,6 @@ enum xnn_status create_fully_connected_nc_f16( /*bias_element_size=*/sizeof(uint16_t), pack_gemm_gio_w, pack_gemm_goi_w, - /*pack_gemm_goi_bl_w=*/NULL, /*packing_params=*/NULL, /*packed_weights_padding_byte=*/0, /*extra_weights_bytes=*/0, @@ -533,7 +483,6 @@ enum xnn_status create_fully_connected_nc_qx8_f16_qc4w( input_stride, output_stride, kernel, /*bias=*/NULL, flags, /*block_size=*/0, - /*extra_bl_bytes=*/0, /*blockwise_kernel_scale_params=*/NULL, /*log2_input_element_size=*/XNN_LOG2_SIZEOF_INT8_T, /*log2_filter_element_size=*/XNN_LOG2_SIZEOF_UINT8_T, @@ -541,7 +490,6 @@ enum xnn_status create_fully_connected_nc_qx8_f16_qc4w( /*bias_element_size=*/sizeof(float), (xnn_packw_gemm_gio_ukernel_fn) gemm_config->pack_gemm_gio, (xnn_packw_gemm_goi_ukernel_fn) gemm_config->pack_gemm_goi, - /*pack_gemm_goi_bl_w=*/NULL, &packing_params, /*packed_weights_padding_byte=*/0, /*extra_weights_bytes=*/sizeof(float) * 2, @@ -708,7 +656,6 @@ enum xnn_status xnn_create_fully_connected_nc_qd8_f16_qb4w( input_stride, output_stride, kernel, bias, flags, /*block_size=*/block_size, - /*extra_bl_bytes=*/sizeof(uint16_t), /*blockwise_kernel_scale_params=*/kernel_scale, /*log2_input_element_size=*/XNN_LOG2_SIZEOF_INT8_T, /*log2_filter_element_size=*/XNN_LOG2_SIZEOF_UINT8_T, @@ -716,7 +663,6 @@ enum xnn_status xnn_create_fully_connected_nc_qd8_f16_qb4w( /*bias_element_size=*/sizeof(float), /*pack_gemm_gio_w,=*/ NULL, /*pack_gemm_goi_w=*/ NULL, - /*pack_gemm_goi_bl_w=*/gemm_config->pack_gemm_goi_bl, /*packing_params=*/&packing_params, /*packed_weights_padding_byte=*/0, /*extra_weights_bytes=*/sizeof(float), @@ -801,7 +747,6 @@ enum xnn_status create_fully_connected_nc_qx8_f32_qc4w( input_stride, output_stride, kernel, /*bias=*/NULL, flags, /*block_size=*/0, - /*extra_bl_bytes=*/0, /*blockwise_kernel_scale_params=*/NULL, /*log2_input_element_size=*/XNN_LOG2_SIZEOF_INT8_T, /*log2_filter_element_size=*/XNN_LOG2_SIZEOF_UINT8_T, @@ -809,7 +754,6 @@ enum xnn_status create_fully_connected_nc_qx8_f32_qc4w( /*bias_element_size=*/sizeof(float), (xnn_packw_gemm_gio_ukernel_fn) gemm_config->pack_gemm_gio, (xnn_packw_gemm_goi_ukernel_fn) gemm_config->pack_gemm_goi, - /*pack_gemm_goi_bl_w=*/NULL, &packing_params, /*packed_weights_padding_byte=*/0, /*extra_weights_bytes=*/sizeof(float) * 2, @@ -1119,7 +1063,6 @@ enum xnn_status xnn_create_fully_connected_nc_qp8_f32_qb4w( input_stride, output_stride, kernel, bias, flags, /*block_size=*/block_size, - /*extra_bl_bytes=*/sizeof(uint16_t), /*blockwise_kernel_scale_params=*/kernel_scale, /*log2_input_element_size=*/XNN_LOG2_SIZEOF_INT8_T, /*log2_filter_element_size=*/XNN_LOG2_SIZEOF_UINT8_T, @@ -1127,7 +1070,6 @@ enum xnn_status xnn_create_fully_connected_nc_qp8_f32_qb4w( /*bias_element_size=*/sizeof(float), /*pack_gemm_gio_w,=*/ NULL, /*pack_gemm_goi_w=*/ NULL, - 
/*pack_gemm_goi_bl_w=*/gemm_config->pack_gemm_goi_bl, &packing_params, /*packed_weights_padding_byte=*/0, /*extra_weights_bytes=*/0, @@ -1244,7 +1186,6 @@ enum xnn_status create_fully_connected_nc_qx8_f32_qb4w( input_stride, output_stride, kernel, bias, flags, /*block_size=*/block_size, - /*extra_bl_bytes=*/sizeof(uint16_t), /*blockwise_kernel_scale_params=*/kernel_scale, /*log2_input_element_size=*/XNN_LOG2_SIZEOF_INT8_T, /*log2_filter_element_size=*/XNN_LOG2_SIZEOF_UINT8_T, @@ -1252,7 +1193,6 @@ enum xnn_status create_fully_connected_nc_qx8_f32_qb4w( /*bias_element_size=*/sizeof(float), /*pack_gemm_gio_w,=*/ NULL, /*pack_gemm_goi_w=*/ NULL, - /*pack_gemm_goi_bl_w=*/gemm_config->pack_gemm_goi_bl, &packing_params, /*packed_weights_padding_byte=*/0, /*extra_weights_bytes=*/sizeof(float), @@ -1375,7 +1315,6 @@ enum xnn_status create_fully_connected_nc_qdx8_f32_qc8w( input_stride, output_stride, kernel, NULL, flags, /*block_size=*/0, - /*extra_bl_bytes=*/0, /*blockwise_kernel_scale_params=*/NULL, /*log2_input_element_size=*/XNN_LOG2_SIZEOF_INT8_T, /*log2_filter_element_size=*/XNN_LOG2_SIZEOF_INT8_T, @@ -1383,7 +1322,6 @@ enum xnn_status create_fully_connected_nc_qdx8_f32_qc8w( /*bias_element_size=*/sizeof(float), (xnn_packw_gemm_gio_ukernel_fn) gemm_config->pack_gemm_gio, (xnn_packw_gemm_goi_ukernel_fn) gemm_config->pack_gemm_goi, - /*pack_gemm_goi_bl_w=*/NULL, &packing_params, /*packed_weights_padding_byte=*/0, /*extra_weights_bytes=*/sizeof(float) * 2, @@ -1506,7 +1444,6 @@ enum xnn_status create_fully_connected_nc_qx8_f16_qc8w( input_stride, output_stride, kernel, NULL, flags, /*block_size=*/0, - /*extra_bl_bytes=*/0, /*blockwise_kernel_scale_params=*/NULL, /*log2_input_element_size=*/XNN_LOG2_SIZEOF_INT8_T, /*log2_filter_element_size=*/XNN_LOG2_SIZEOF_INT8_T, @@ -1514,7 +1451,6 @@ enum xnn_status create_fully_connected_nc_qx8_f16_qc8w( /*bias_element_size=*/sizeof(float), (xnn_packw_gemm_gio_ukernel_fn) gemm_config->pack_gemm_gio, (xnn_packw_gemm_goi_ukernel_fn) gemm_config->pack_gemm_goi, - /*pack_gemm_goi_bl_w=*/NULL, &packing_params, /*packed_weights_padding_byte=*/0, /*extra_weights_bytes=*/sizeof(float) * 2, @@ -1662,7 +1598,6 @@ enum xnn_status create_fully_connected_nc_f32( input_stride, output_stride, kernel, bias, flags, /*block_size=*/0, - /*extra_bl_bytes=*/0, /*blockwise_kernel_scale_params=*/NULL, /*log2_input_element_size=*/XNN_LOG2_SIZEOF_FLOAT, /*log2_filter_element_size=*/XNN_LOG2_SIZEOF_FLOAT, @@ -1670,7 +1605,6 @@ enum xnn_status create_fully_connected_nc_f32( /*bias_element_size=*/sizeof(float), (xnn_packw_gemm_gio_ukernel_fn) gemm_config->pack_gemm_gio, (xnn_packw_gemm_goi_ukernel_fn) gemm_config->pack_gemm_goi, - /*pack_gemm_goi_bl_w=*/NULL, /*packing_params=*/NULL, /*packed_weights_padding_byte=*/0, /*extra_weights_bytes=*/0, @@ -1807,7 +1741,6 @@ enum xnn_status xnn_create_fully_connected_nc_f32_qc4w( input_stride, output_stride, kernel, bias, flags, /*block_size=*/0, - /*extra_bl_bytes=*/0, /*blockwise_kernel_scale_params=*/NULL, /*log2_input_element_size=*/XNN_LOG2_SIZEOF_FLOAT, // Pass 1 byte even though it is half byte, we handle the division via filter_is_nibble == true. 
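Aside: the extra_bl_bytes and pack_gemm_goi_bl arguments dropped from the call sites above are subsumed by the stride query. Below is a minimal standalone restatement of the blockwise (qb4) stride arithmetic; the helper name and the inline round-up are illustrative only. The in-tree version is xnn_packed_stride_qb4_weights_and_biases, added in src/reference/packing.cc later in this patch, and the sketch assumes a power-of-two planes count and a block_size that evenly divides the padded K.

#include <stddef.h>
#include <stdint.h>

// Illustrative restatement of the qb4 packed-weights stride: k_stride bytes
// of int8-addressed nibble data, an int32 bias slot, any per-N extra bytes,
// then one bf16 scale per block of the plane-padded K.
static size_t qb4_packed_weights_stride(size_t k, size_t block_size,
                                        size_t k_stride, size_t planes,
                                        size_t extra_bytes) {
  const size_t padded_k = (k + planes - 1) & ~(planes - 1);  // round_up_po2
  const size_t block_scale_bytes = (padded_k / block_size) * sizeof(uint16_t);
  return k_stride * sizeof(int8_t) + sizeof(int32_t) + extra_bytes +
         block_scale_bytes;
}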
@@ -1816,7 +1749,6 @@ enum xnn_status xnn_create_fully_connected_nc_f32_qc4w( /*bias_element_size=*/sizeof(float), (xnn_packw_gemm_gio_ukernel_fn) NULL, (xnn_packw_gemm_goi_ukernel_fn) gemm_config->pack_gemm_goi, - /*pack_gemm_goi_bl_w=*/NULL, /*packing_params=*/NULL, /*packed_weights_padding_byte=*/0, /*extra_weights_bytes=*/sizeof(float), @@ -1900,7 +1832,6 @@ enum xnn_status xnn_create_fully_connected_nc_f32_qc8w( input_stride, output_stride, kernel, bias, flags, /*block_size=*/0, - /*extra_bl_bytes=*/0, /*blockwise_kernel_scale_params=*/NULL, /*log2_input_element_size=*/XNN_LOG2_SIZEOF_FLOAT, /*log2_filter_element_size=*/XNN_LOG2_SIZEOF_INT8_T, @@ -1908,7 +1839,6 @@ enum xnn_status xnn_create_fully_connected_nc_f32_qc8w( /*bias_element_size=*/sizeof(float), (xnn_packw_gemm_gio_ukernel_fn) gemm_config->pack_gemm_gio, (xnn_packw_gemm_goi_ukernel_fn) gemm_config->pack_gemm_goi, - /*pack_gemm_goi_bl_w=*/NULL, /*packing_params=*/NULL, /*packed_weights_padding_byte=*/0, /*extra_weights_bytes=*/sizeof(float), @@ -2000,7 +1930,6 @@ enum xnn_status xnn_create_fully_connected_nc_qs8( input_stride, output_stride, kernel, bias, flags, /*block_size=*/0, - /*extra_bl_bytes=*/0, /*blockwise_kernel_scale_params=*/NULL, /*log2_input_element_size=*/XNN_LOG2_SIZEOF_INT8_T, /*log2_filter_element_size=*/XNN_LOG2_SIZEOF_INT8_T, @@ -2008,7 +1937,6 @@ enum xnn_status xnn_create_fully_connected_nc_qs8( /*bias_element_size=*/sizeof(int32_t), (xnn_packw_gemm_gio_ukernel_fn) gemm_config->pack_gemm_gio, (xnn_packw_gemm_goi_ukernel_fn) gemm_config->pack_gemm_goi, - /*pack_gemm_goi_bl_w=*/NULL, &packing_params, /*packed_weights_padding_byte=*/0, /*extra_weights_bytes=*/sizeof(float), @@ -2107,7 +2035,6 @@ enum xnn_status xnn_create_fully_connected_nc_qs8_qc8w( input_stride, output_stride, kernel, bias, flags, /*block_size=*/0, - /*extra_bl_bytes=*/0, /*blockwise_kernel_scale_params=*/NULL, /*log2_input_element_size=*/XNN_LOG2_SIZEOF_INT8_T, /*log2_filter_element_size=*/XNN_LOG2_SIZEOF_INT8_T, @@ -2115,7 +2042,6 @@ enum xnn_status xnn_create_fully_connected_nc_qs8_qc8w( /*bias_element_size=*/sizeof(int32_t), (xnn_packw_gemm_gio_ukernel_fn) gemm_config->pack_gemm_gio, (xnn_packw_gemm_goi_ukernel_fn) gemm_config->pack_gemm_goi, - /*pack_gemm_goi_bl_w=*/NULL, &packing_params, /*packed_weights_padding_byte=*/0, /*extra_weights_bytes=*/sizeof(float), @@ -2207,7 +2133,6 @@ enum xnn_status xnn_create_fully_connected_nc_qu8( input_stride, output_stride, kernel, bias, flags, /*block_size=*/0, - /*extra_bl_bytes=*/0, /*blockwise_kernel_scale_params=*/NULL, /*log2_input_element_size=*/XNN_LOG2_SIZEOF_UINT8_T, /*log2_filter_element_size=*/XNN_LOG2_SIZEOF_UINT8_T, @@ -2215,7 +2140,6 @@ enum xnn_status xnn_create_fully_connected_nc_qu8( /*bias_element_size=*/sizeof(int32_t), (xnn_packw_gemm_gio_ukernel_fn) gemm_config->pack_gemm_gio, (xnn_packw_gemm_goi_ukernel_fn) gemm_config->pack_gemm_goi, - /*pack_gemm_goi_bl_w=*/NULL, &packing_params, /*packed_weights_padding_byte=*/kernel_zero_point, /*extra_weights_bytes=*/0, diff --git a/src/reference/packing.cc b/src/reference/packing.cc index 60f64b94114..631d075612f 100644 --- a/src/reference/packing.cc +++ b/src/reference/packing.cc @@ -19,6 +19,7 @@ #include "xnnpack/math.h" #include "xnnpack/microfnptr.h" #include "xnnpack/microparams.h" +#include "xnnpack/microparams-init.h" #include "xnnpack/pack.h" #include "xnnpack/unaligned.h" @@ -1413,8 +1414,8 @@ void pack_weights_and_biases(uint32_t flags, // } size_t xnn_packed_stride_qs8_weights_and_biases( - const struct xnn_gemm_config* 
gemm_config, size_t unused_k, size_t k_stride, - size_t extra_bytes) { + const struct xnn_gemm_config* gemm_config, size_t unused_k, size_t unused_block_size, + size_t k_stride, size_t extra_bytes) { const size_t bias_element_size = sizeof(int32_t); const size_t log2_filter_element_size = XNN_LOG2_SIZEOF_INT8_T; return (k_stride << log2_filter_element_size) + bias_element_size + @@ -1423,7 +1424,7 @@ size_t xnn_packed_stride_qs8_weights_and_biases( void xnn_pack_qs8_weights_and_biases( uint32_t flags, const struct xnn_gemm_config* gemm_config, - size_t input_channels, size_t output_channels, size_t groups, + size_t input_channels, size_t output_channels, size_t groups, size_t unused_block_size, size_t k_stride, const void* accumulator_init, const void* weights, xnn_init_scale_params_fn init_extra_data0_fn, const void* extra_data0, size_t extra_data0_element_size, @@ -1433,7 +1434,7 @@ void xnn_pack_qs8_weights_and_biases( const size_t extra_bytes = extra_data0_element_size + extra_data1_element_size; const size_t weights_stride = xnn_packed_stride_qs8_weights_and_biases( - gemm_config, input_channels, k_stride, extra_bytes); + gemm_config, input_channels, unused_block_size, k_stride, extra_bytes); return pack_weights_and_biases( flags, gemm_config, input_channels, output_channels, groups, weights_stride, (xnn_packw_gemm_gio_ukernel_fn)xnn_pack_qs8_gemm_gio_w, @@ -1444,8 +1445,8 @@ void xnn_pack_qs8_weights_and_biases( } size_t xnn_packed_stride_qs4_weights_and_biases( - const struct xnn_gemm_config* gemm_config, size_t unused_k, size_t k_stride, - size_t extra_bytes) { + const struct xnn_gemm_config* gemm_config, size_t unused_k, size_t unused_block_size, + size_t k_stride, size_t extra_bytes) { const size_t bias_element_size = sizeof(int32_t); const size_t log2_filter_element_size = XNN_LOG2_SIZEOF_INT8_T; return (k_stride << log2_filter_element_size) + bias_element_size + @@ -1455,7 +1456,7 @@ size_t xnn_packed_stride_qs4_weights_and_biases( void xnn_pack_qs4_weights_and_biases( uint32_t flags, const struct xnn_gemm_config* gemm_config, size_t input_channels, size_t output_channels, size_t groups, - size_t k_stride, const void* accumulator_init, const void* weights, + size_t unused_block_size, size_t k_stride, const void* accumulator_init, const void* weights, xnn_init_scale_params_fn init_extra_data0_fn, const void* extra_data0, size_t extra_data0_element_size, xnn_init_scale_params_fn init_extra_data1_fn, const void* extra_data1, @@ -1464,7 +1465,7 @@ void xnn_pack_qs4_weights_and_biases( const size_t extra_bytes = extra_data0_element_size + extra_data1_element_size; const size_t weights_stride = xnn_packed_stride_qs8_weights_and_biases( - gemm_config, input_channels, k_stride, extra_bytes); + gemm_config, input_channels, unused_block_size, k_stride, extra_bytes); return pack_weights_and_biases( flags, gemm_config, input_channels, output_channels, groups, weights_stride, @@ -1475,9 +1476,111 @@ void xnn_pack_qs4_weights_and_biases( extra_data1_element_size, packed_weights_ptr, extra_bytes, params); } +size_t xnn_packed_stride_qb4_weights_and_biases( + const struct xnn_gemm_config* gemm_config, size_t k, size_t block_size, + size_t k_stride, size_t extra_bytes) { + const uint32_t kr = UINT32_C(1) << gemm_config->log2_kr; + const uint32_t sr = UINT32_C(1) << gemm_config->log2_sr; + const uint32_t nr = gemm_config->nr; + const size_t planes = gemm_config->planes; + + size_t input_channels = round_up_po2(k, planes); + + size_t block_scale_bytes = 0; + size_t num_blocks = 0; + const bool 
block_wise = (block_size != 0); + if (block_wise) { + num_blocks = input_channels / block_size; + block_scale_bytes += num_blocks * sizeof(uint16_t); + } + + const size_t bias_element_size = sizeof(int32_t); + const size_t log2_filter_element_size = XNN_LOG2_SIZEOF_INT8_T; + return (k_stride << log2_filter_element_size) + bias_element_size + + extra_bytes + block_scale_bytes; +} + +void xnn_pack_qb4_weights_and_biases( + uint32_t flags, const struct xnn_gemm_config* gemm_config, + size_t input_channels, size_t output_channels, size_t groups, + size_t block_size, size_t k_stride, const void* accumulator_init, const void* weights, + xnn_init_scale_params_fn init_extra_data0_fn, const void* extra_data0, + size_t extra_data0_element_size, + xnn_init_scale_params_fn init_extra_data1_fn, const void* extra_data1, + size_t extra_data1_element_size, void* packed_weights_ptr, + const void* params) { + + const uint32_t nr = gemm_config->nr; + const uint32_t kr = UINT32_C(1) << gemm_config->log2_kr; + const uint32_t sr = UINT32_C(1) << gemm_config->log2_sr; + const size_t planes = gemm_config->planes; + + const size_t extra_bytes_bl = sizeof(uint16_t); + const size_t extra_bytes_n = sizeof(uint32_t); + if (flags & XNN_FLAG_TRANSPOSE_WEIGHTS) { + xnn_pack_qs8_qb4w_gemm_gio_w( + /*g=*/groups, + /*nc=*/output_channels, + /*kc=*/input_channels, + /*nr=*/nr, + /*kr=*/kr, + /*sr=*/sr, + /*k_stride=*/k_stride, + /*bl=*/block_size, + /*k=*/(const uint8_t*)weights, + /*bias=*/NULL, + /*scale=*/(const xnn_bfloat16*)extra_data1, + /*packed_weights=*/packed_weights_ptr, + /*extra_bytes_bl=*/nr * extra_bytes_bl, + /*extra_bytes_n=*/nr * extra_bytes_n, + /*params*/(const struct xnn_qs8_qc4w_packing_params *)params); + } else { + xnn_pack_qs8_qb4w_gemm_goi_w( + /*g=*/groups, + /*nc=*/output_channels, + /*kc=*/input_channels, + /*nr=*/nr, + /*kr=*/kr, + /*sr=*/sr, + /*bl=*/block_size, + /*k=*/(const uint8_t*)weights, + /*bias=*/NULL, + /*scale=*/(const xnn_bfloat16*)extra_data1, + /*packed_weights=*/packed_weights_ptr, + /*extra_bytes_bl=*/nr * extra_bytes_bl, + /*extra_bytes_n=*/nr * extra_bytes_n, + /*params*/(const struct xnn_qs8_qc4w_packing_params *)params); + } + + // fill in kernel scales + const size_t num_blocks = input_channels / block_size; + const size_t weights_stride = xnn_packed_stride_qb4_weights_and_biases(gemm_config, input_channels, block_size, k_stride, extra_bytes_n); + void* weights_start = (void*) ((uintptr_t) packed_weights_ptr + + nr * (sizeof(float) + (block_size * sizeof(int8_t) / 2))); + + const size_t block_stride = /*weights*/block_size / 2 + sizeof(uint16_t); + xnn_init_blockwise_scale_bf16_params( + output_channels, nr, nr, + nr * weights_stride, + nr * weights_stride, + /*num_blocks=*/num_blocks, + /*block_stride=*/gemm_config->nr * block_stride, + 0, + (const xnn_bfloat16*)extra_data1, weights_start); + + // fill in bias if not null + if (accumulator_init != nullptr) { + weights_start = (void*) ((uintptr_t) packed_weights_ptr + gemm_config->nr * (weights_stride - sizeof(float))) ; + xnn_init_qs8_qc8w_scale_fp32_params( + output_channels, gemm_config->nr, gemm_config->nr, + gemm_config->nr * weights_stride, gemm_config->nr * weights_stride, 0, + (const float*)accumulator_init, weights_start); + } +} + size_t xnn_packed_stride_qu8_weights_and_biases( - const struct xnn_gemm_config* gemm_config, size_t unused_k, size_t k_stride, - size_t extra_bytes) { + const struct xnn_gemm_config* gemm_config, size_t unused_k, size_t unused_block_size, + size_t k_stride, size_t extra_bytes) { const 
size_t bias_element_size = sizeof(int32_t); const size_t log2_filter_element_size = XNN_LOG2_SIZEOF_INT8_T; return (k_stride << log2_filter_element_size) + bias_element_size + @@ -1487,7 +1590,7 @@ size_t xnn_packed_stride_qu8_weights_and_biases( void xnn_pack_qu8_weights_and_biases( uint32_t flags, const struct xnn_gemm_config* gemm_config, size_t input_channels, size_t output_channels, size_t groups, - size_t k_stride, const void* accumulator_init, const void* weights, + size_t unused_block_size, size_t k_stride, const void* accumulator_init, const void* weights, xnn_init_scale_params_fn init_extra_data0_fn, const void* extra_data0, size_t extra_data0_element_size, xnn_init_scale_params_fn init_extra_data1_fn, const void* extra_data1, @@ -1496,7 +1599,7 @@ void xnn_pack_qu8_weights_and_biases( const size_t extra_bytes = extra_data0_element_size + extra_data1_element_size; const size_t weights_stride = xnn_packed_stride_qs8_weights_and_biases( - gemm_config, input_channels, k_stride, extra_bytes); + gemm_config, input_channels, unused_block_size, k_stride, extra_bytes); return pack_weights_and_biases( flags, gemm_config, input_channels, output_channels, groups, weights_stride, (xnn_packw_gemm_gio_ukernel_fn)xnn_pack_qu8_gemm_gio_w, @@ -1508,7 +1611,7 @@ void xnn_pack_qu8_weights_and_biases( #if XNN_ENABLE_KLEIDIAI size_t xnn_packed_stride_kai_qs4_weights_and_biases( - const struct xnn_gemm_config* gemm_config, size_t k, size_t unused_k_stride, + const struct xnn_gemm_config* gemm_config, size_t k, size_t unused_block_size, size_t unused_k_stride, size_t extra_bytes) { const uint32_t kr = UINT32_C(1) << gemm_config->log2_kr; const uint32_t sr = UINT32_C(1) << gemm_config->log2_sr; @@ -1519,7 +1622,7 @@ size_t xnn_packed_stride_kai_qs4_weights_and_biases( void xnn_pack_kai_qs4_weights_and_biases( uint32_t flags, const struct xnn_gemm_config* gemm_config, size_t input_channels, size_t output_channels, size_t groups, - size_t k_stride, const void* accumulator_init, const void* weights, + size_t unused_block_size, size_t k_stride, const void* accumulator_init, const void* weights, xnn_init_scale_params_fn init_extra_data0_fn, const void* extra_data0, size_t extra_data0_element_size, xnn_init_scale_params_fn init_extra_data1_fn, const void* extra_data1, @@ -1659,7 +1762,7 @@ void xnn_pack_kai_qs8_weights_and_biases( } size_t xnn_packed_stride_kai_f32_weights_and_biases( - const struct xnn_gemm_config* gemm_config, size_t k, size_t unused_k_stride, + const struct xnn_gemm_config* gemm_config, size_t unused_block_size, size_t k, size_t unused_k_stride, size_t extra_bytes) { size_t ret_val = kai_get_rhs_packed_stride_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(k) / @@ -1732,7 +1835,7 @@ void transpose_weights(const float* in, float* out, size_t height, void xnn_pack_kai_f32_weights_and_biases( uint32_t flags, const struct xnn_gemm_config* gemm_config, size_t input_channels, size_t output_channels, size_t groups, - size_t k_stride, const void* accumulator_init, const void* weights, + size_t unused_block_size, size_t k_stride, const void* accumulator_init, const void* weights, xnn_init_scale_params_fn init_extra_data0_fn, const void* extra_data0, size_t extra_data0_element_size, xnn_init_scale_params_fn init_extra_data1_fn, const void* extra_data1, @@ -1784,7 +1887,7 @@ void xnn_pack_kai_f32_weights_and_biases( size_t xnn_packed_stride_kai_qb4_weights_and_biases( const struct xnn_gemm_config* gemm_config, size_t k, size_t block_size, - size_t extra_bytes) { + size_t unused_k_stride, size_t 
extra_bytes) { const uint32_t kr = UINT32_C(1) << gemm_config->log2_kr; const uint32_t sr = UINT32_C(1) << gemm_config->log2_sr; const uint32_t nr = gemm_config->nr; @@ -1802,7 +1905,7 @@ size_t xnn_packed_stride_kai_qb4_weights_and_biases( void xnn_pack_kai_qb4_weights_and_biases( uint32_t flags, const struct xnn_gemm_config* gemm_config, size_t input_channels, size_t output_channels, size_t groups, - size_t block_size, const void* accumulator_init, const void* weights, + size_t block_size, size_t unused_k_stride, const void* accumulator_init, const void* weights, xnn_init_scale_params_fn init_extra_data0_fn, const void* extra_data0, size_t extra_data0_element_size, xnn_init_scale_params_fn init_extra_data1_fn, const void* extra_data1, @@ -1850,6 +1953,19 @@ void xnn_pack_kai_qb4_weights_and_biases( /*rhs_packed*/ packed_weights_ptr, /*extra_bytes=*/0, &kai_params); } + + // init bias + const size_t weights_stride = xnn_packed_stride_kai_qb4_weights_and_biases( + gemm_config, input_channels, block_size, unused_k_stride, 0); + if (accumulator_init != NULL) { + void* weights_start = (void*) ((uintptr_t) packed_weights_ptr + + nr * (sizeof(float) + (block_size * sizeof(int8_t) / 2))); + weights_start = (void*) ((uintptr_t) packed_weights_ptr + nr * (weights_stride - sizeof(float))) ; + xnn_init_qs8_qc8w_scale_fp32_params( + output_channels, nr, nr, + nr * weights_stride, nr * weights_stride, 0, + (const float*)accumulator_init, weights_start); + } } #endif // XNN_ENABLE_KLEIDIAI diff --git a/src/xnnpack/config-types.h b/src/xnnpack/config-types.h index 3abb6e22844..7bbe8908108 100644 --- a/src/xnnpack/config-types.h +++ b/src/xnnpack/config-types.h @@ -194,8 +194,6 @@ struct xnn_gemm_config { xnn_packw_gemm_gio_ukernel_fn pack_gemm_gio; // Deprecated. Use pack_weights_and_biases instead. xnn_packw_gemm_goi_ukernel_fn pack_gemm_goi; - // TODO(b/346765736): Use pack_weights_and_biases instead. - xnn_packw_gemm_goi_bl_ukernel_fn pack_gemm_goi_bl; xnn_pack_conv_goki_w_fn pack_igemm_goki; xnn_pack_conv_kgo_w_fn pack_igemm_kgo; xnn_pack_deconv_goki_w_fn pack_deconv_goki; diff --git a/src/xnnpack/microfnptr.h b/src/xnnpack/microfnptr.h index 94bf319b022..48762f6fc39 100644 --- a/src/xnnpack/microfnptr.h +++ b/src/xnnpack/microfnptr.h @@ -2244,6 +2244,7 @@ typedef void (*xnn_pack_weights_and_biases_fn)( size_t input_channels, // size_t output_channels, // size_t groups, // + size_t block_size, // // We tile packing by output channels, in GIO layout, the k (row) index // needs to be able to skip by the actual number of output channels, and not // just the argument nc. E.g. 
if weights is 1x3x5, and nr is 2, we tile the @@ -2270,6 +2271,7 @@ typedef void (*xnn_pack_weights_and_biases_fn)( typedef size_t (*xnn_packed_stride_weights_and_biases_fn)( const struct xnn_gemm_config* gemm_config, // size_t k, // + size_t block_size, // size_t k_stride, // size_t extra_bytes); diff --git a/src/xnnpack/pack.h b/src/xnnpack/pack.h index 6791f846f48..a15a2759365 100644 --- a/src/xnnpack/pack.h +++ b/src/xnnpack/pack.h @@ -387,6 +387,7 @@ XNN_INTERNAL void xnn_pack_qs8_weights_and_biases( size_t input_channels, // size_t output_channels, // size_t groups, // + size_t block_size, // size_t k_stride, // const void* accumulator_init, // const void* weights, // @@ -402,6 +403,7 @@ XNN_INTERNAL void xnn_pack_qs8_weights_and_biases( XNN_INTERNAL size_t xnn_packed_stride_qs8_weights_and_biases( const struct xnn_gemm_config* gemm_config, // size_t k, // + size_t block_size, // size_t k_stride, // size_t extra_bytes); @@ -412,6 +414,7 @@ XNN_INTERNAL void xnn_pack_qs4_weights_and_biases( size_t input_channels, // size_t output_channels, // size_t groups, // + size_t block_size, // size_t k_stride, // const void* accumulator_init, // const void* weights, // @@ -427,6 +430,33 @@ XNN_INTERNAL void xnn_pack_qs4_weights_and_biases( XNN_INTERNAL size_t xnn_packed_stride_qs4_weights_and_biases( const struct xnn_gemm_config* gemm_config, // size_t k, // + size_t block_size, // + size_t k_stride, // + size_t extra_bytes); + +XNN_INTERNAL void xnn_pack_qb4_weights_and_biases( + uint32_t flags, // + const struct xnn_gemm_config* gemm_config, // + size_t input_channels, // + size_t output_channels, // + size_t groups, // + size_t block_size, // + size_t k_stride, // + const void* accumulator_init, // + const void* weights, // + xnn_init_scale_params_fn init_extra_data0_fn, // + const void* extra_data0, // + size_t extra_data0_element_size, // + xnn_init_scale_params_fn init_extra_data1_fn, // + const void* extra_data1, // + size_t extra_data1_element_size, // + void* packed_weights_ptr, // + const void* params); + +XNN_INTERNAL size_t xnn_packed_stride_qb4_weights_and_biases( + const struct xnn_gemm_config* gemm_config, // + size_t k, // + size_t block_size, // size_t k_stride, // size_t extra_bytes); @@ -436,6 +466,7 @@ XNN_INTERNAL void xnn_pack_qu8_weights_and_biases( size_t input_channels, // size_t output_channels, // size_t groups, // + size_t block_size, // size_t k_stride, // const void* accumulator_init, // const void* weights, // @@ -451,6 +482,7 @@ XNN_INTERNAL void xnn_pack_qu8_weights_and_biases( XNN_INTERNAL size_t xnn_packed_stride_qu8_weights_and_biases( const struct xnn_gemm_config* gemm_config, // size_t k, // + size_t block_size, // size_t k_stride, // size_t extra_bytes); @@ -461,6 +493,7 @@ XNN_INTERNAL void xnn_pack_kai_qs4_weights_and_biases( size_t input_channels, // size_t output_channels, // size_t groups, // + size_t block_size, // size_t k_stride, // const void* accumulator_init, // const void* weights, // @@ -476,6 +509,7 @@ XNN_INTERNAL void xnn_pack_kai_qs4_weights_and_biases( XNN_INTERNAL size_t xnn_packed_stride_kai_qs4_weights_and_biases( const struct xnn_gemm_config* gemm_config, // size_t k, // + size_t block_size, // size_t k_stride, // size_t extra_bytes); @@ -508,8 +542,8 @@ size_t xnn_packed_stride_kai_f16_weights_and_biases( size_t extra_bytes); size_t xnn_packed_stride_kai_f32_weights_and_biases( - const struct xnn_gemm_config* gemm_config, size_t k, size_t unused_k_stride, - size_t extra_bytes); + const struct xnn_gemm_config* gemm_config, size_t k, 
size_t block_size, + size_t unused_k_stride, size_t extra_bytes); void xnn_pack_kai_f16_weights_and_biases( uint32_t flags, const struct xnn_gemm_config* gemm_config, @@ -524,7 +558,7 @@ void xnn_pack_kai_f16_weights_and_biases( void xnn_pack_kai_f32_weights_and_biases( uint32_t flags, const struct xnn_gemm_config* gemm_config, size_t input_channels, size_t output_channels, size_t groups, - size_t k_stride, const void* accumulator_init, const void* weights, + size_t block_size, size_t k_stride, const void* accumulator_init, const void* weights, xnn_init_scale_params_fn init_extra_data0_fn, const void* extra_data0, size_t extra_data0_element_size, xnn_init_scale_params_fn init_extra_data1_fn, const void* extra_data1, @@ -538,6 +572,7 @@ XNN_INTERNAL void xnn_pack_kai_qb4_weights_and_biases( size_t output_channels, // size_t groups, // size_t block_size, // + size_t k_stride, // const void* accumulator_init, // const void* weights, // xnn_init_scale_params_fn init_extra_data0_fn, // @@ -553,6 +588,7 @@ XNN_INTERNAL size_t xnn_packed_stride_kai_qb4_weights_and_biases( const struct xnn_gemm_config* gemm_config, // size_t k, // size_t block_size, // + size_t k_stride, // size_t extra_bytes); #endif // XNN_ENABLE_KLEIDIAI From 8ea3af3f796c065971003e531f76ec1f808f06d6 Mon Sep 17 00:00:00 2001 From: Max Ren Date: Thu, 5 Dec 2024 15:29:04 -0500 Subject: [PATCH 2/3] add back pack_gemm_goi_bl ukernel --- src/xnnpack/config-types.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/xnnpack/config-types.h b/src/xnnpack/config-types.h index 7bbe8908108..3abb6e22844 100644 --- a/src/xnnpack/config-types.h +++ b/src/xnnpack/config-types.h @@ -194,6 +194,8 @@ struct xnn_gemm_config { xnn_packw_gemm_gio_ukernel_fn pack_gemm_gio; // Deprecated. Use pack_weights_and_biases instead. xnn_packw_gemm_goi_ukernel_fn pack_gemm_goi; + // TODO(b/346765736): Use pack_weights_and_biases instead. 
+ xnn_packw_gemm_goi_bl_ukernel_fn pack_gemm_goi_bl; xnn_pack_conv_goki_w_fn pack_igemm_goki; xnn_pack_conv_kgo_w_fn pack_igemm_kgo; xnn_pack_deconv_goki_w_fn pack_deconv_goki; From c6fff3c3d7a537cbeb66ae73aa438d1e542253a4 Mon Sep 17 00:00:00 2001 From: Max Ren Date: Fri, 10 Jan 2025 15:11:05 -0800 Subject: [PATCH 3/3] merge conflicts and fix all failures --- bench/gemm-benchmark.cc | 9 ++-- src/operators/fully-connected-nc.c | 3 +- src/reference/packing.cc | 23 ++++----- src/xnnpack/pack.h | 76 ++++++++++++++++++++---------- test/gemm-microkernel-tester.cc | 15 +++--- 5 files changed, 78 insertions(+), 48 deletions(-) diff --git a/bench/gemm-benchmark.cc b/bench/gemm-benchmark.cc index d7eb7de4383..d8b152dc5de 100644 --- a/bench/gemm-benchmark.cc +++ b/bench/gemm-benchmark.cc @@ -735,7 +735,7 @@ void GEMMBenchmark(benchmark::State& state, gemm_config.log2_sr = static_cast(31 - math_clz_nonzero_u32(sr)); const size_t packed_w_stride = - packed_stride(&gemm_config, kc, /*k_stride=*/kc, /*extra_bytes=*/0); + packed_stride(&gemm_config, kc, /*unused_block_size=*/0, /*k_stride=*/kc, /*extra_bytes=*/0); const size_t packed_w_size = packed_w_stride * round_up(nc, nr); const size_t c_elements = mc * nc; @@ -760,7 +760,7 @@ void GEMMBenchmark(benchmark::State& state, const xnn_qs8_qc4w_packing_params packing_params = {/*input_zero_point=*/1, /*kernel_zero_point=*/8}; pack_weights(/*flags=*/0, &gemm_config, kc, nc, - /*groups=*/1, /*k_stride=*/kc, + /*groups=*/1, /*unused_block_size=*/0, /*k_stride=*/kc, /*accumulator_init=*/nullptr, /*weights=*/k.data(), /*int_extra_data0_fn=*/nullptr, @@ -852,7 +852,7 @@ void GEMMBenchmark(benchmark::State& state, gemm_config.log2_sr = static_cast(31 - math_clz_nonzero_u32(sr)); const size_t packed_w_stride = - packed_stride(&gemm_config, k2, /*k_stride=*/bl, /*extra_bytes=*/0); + packed_stride(&gemm_config, k2, /*block_size=*/bl, /*k_stride=*/kc, /*extra_bytes=*/0); const size_t packed_w_size = packed_w_stride * round_up(nc, nr); const size_t c_elements = mc * nc; @@ -879,7 +879,8 @@ void GEMMBenchmark(benchmark::State& state, const xnn_qs8_qc4w_packing_params packing_params = {/*input_zero_point=*/1, /*kernel_zero_point=*/8}; pack_weights(/*flags=*/0, &gemm_config, k2, nc, - /*groups=*/1, /*k_stride=*/bl, + /*groups=*/1, /*block_size=*/bl, + /*k_stride=*/kc, /*accumulator_init=*/nullptr, /*weights=*/k.data(), /*int_extra_data0_fn=*/nullptr, diff --git a/src/operators/fully-connected-nc.c b/src/operators/fully-connected-nc.c index 83e8bc9e743..6d4bf5fc960 100644 --- a/src/operators/fully-connected-nc.c +++ b/src/operators/fully-connected-nc.c @@ -865,7 +865,6 @@ static enum xnn_status xnn_create_fully_connected_nc_qp8_f32_qcxw( input_channels, output_channels, input_stride, output_stride, kernel, /*bias=*/NULL, flags, /*block_size=*/0, - /*extra_bl_bytes=*/0, /*blockwise_kernel_scale_params=*/NULL, /*log2_input_element_size=*/XNN_LOG2_SIZEOF_INT8_T, /*log2_filter_element_size=*/XNN_LOG2_SIZEOF_UINT8_T, @@ -873,7 +872,7 @@ static enum xnn_status xnn_create_fully_connected_nc_qp8_f32_qcxw( /*bias_element_size=*/sizeof(float), (xnn_packw_gemm_gio_ukernel_fn)gemm_config->pack_gemm_gio, (xnn_packw_gemm_goi_ukernel_fn)gemm_config->pack_gemm_goi, - /*pack_gemm_goi_bl_w=*/NULL, packing_params, + packing_params, /*packed_weights_padding_byte=*/0, /*extra_weights_bytes=*/0, /*init_scale_params=*/NULL, diff --git a/src/reference/packing.cc b/src/reference/packing.cc index 631d075612f..002beee3556 100644 --- a/src/reference/packing.cc +++ b/src/reference/packing.cc @@ 
-1350,6 +1350,7 @@ void pack_weights_and_biases(uint32_t flags, // size_t input_channels, // size_t output_channels, // size_t groups, // + size_t unused_block_size, // size_t weights_stride, // xnn_packw_gemm_gio_ukernel_fn pack_gemm_gio_w, // xnn_packw_gemm_goi_ukernel_fn pack_gemm_goi_w, // @@ -1436,7 +1437,7 @@ void xnn_pack_qs8_weights_and_biases( const size_t weights_stride = xnn_packed_stride_qs8_weights_and_biases( gemm_config, input_channels, unused_block_size, k_stride, extra_bytes); return pack_weights_and_biases( - flags, gemm_config, input_channels, output_channels, groups, + flags, gemm_config, input_channels, output_channels, groups, unused_block_size, weights_stride, (xnn_packw_gemm_gio_ukernel_fn)xnn_pack_qs8_gemm_gio_w, (xnn_packw_gemm_goi_ukernel_fn)xnn_pack_qs8_gemm_goi_w, accumulator_init, weights, init_extra_data0_fn, extra_data0, extra_data0_element_size, @@ -1468,7 +1469,7 @@ void xnn_pack_qs4_weights_and_biases( gemm_config, input_channels, unused_block_size, k_stride, extra_bytes); return pack_weights_and_biases( flags, gemm_config, input_channels, output_channels, groups, - weights_stride, + unused_block_size, weights_stride, (xnn_packw_gemm_gio_ukernel_fn)xnn_pack_qs8_qc4w_gemm_gio_w, (xnn_packw_gemm_goi_ukernel_fn)xnn_pack_qs8_qc4w_gemm_goi_w, accumulator_init, weights, init_extra_data0_fn, extra_data0, @@ -1601,7 +1602,7 @@ void xnn_pack_qu8_weights_and_biases( const size_t weights_stride = xnn_packed_stride_qs8_weights_and_biases( gemm_config, input_channels, unused_block_size, k_stride, extra_bytes); return pack_weights_and_biases( - flags, gemm_config, input_channels, output_channels, groups, + flags, gemm_config, input_channels, output_channels, groups, unused_block_size, weights_stride, (xnn_packw_gemm_gio_ukernel_fn)xnn_pack_qu8_gemm_gio_w, (xnn_packw_gemm_goi_ukernel_fn)xnn_pack_qu8_gemm_goi_w, accumulator_init, weights, init_extra_data0_fn, extra_data0, extra_data0_element_size, @@ -1664,8 +1665,8 @@ void xnn_pack_kai_qs4_weights_and_biases( } size_t xnn_packed_stride_kai_f16_weights_and_biases( - const struct xnn_gemm_config* gemm_config, size_t k, size_t unused_k_stride, - size_t extra_bytes) { + const struct xnn_gemm_config* gemm_config, size_t k, size_t unused_block_size, + size_t unused_k_stride, size_t extra_bytes) { size_t ret_val = kai_get_rhs_packed_stride_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme(k) / kai_get_n_step_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme(); @@ -1682,8 +1683,8 @@ void transpose_weights_x16(const xnn_float16* in, xnn_float16* out, } size_t xnn_packed_stride_kai_qs8_weights_and_biases( - const struct xnn_gemm_config* gemm_config, size_t k, size_t unused_k_stride, - size_t extra_bytes) { + const struct xnn_gemm_config* gemm_config, size_t k, size_t unused_block_size, + size_t unused_k_stride, size_t extra_bytes) { const uint32_t kr = UINT32_C(1) << gemm_config->log2_kr; const uint32_t sr = UINT32_C(1) << gemm_config->log2_sr; return kai_get_rhs_packed_stride_rhs_pack_kxn_qsi8cxp_qsi8cx_neon(k, /*nr=*/1, @@ -1693,7 +1694,7 @@ size_t xnn_packed_stride_kai_qs8_weights_and_biases( void xnn_pack_kai_qs8_weights_and_biases( uint32_t flags, const struct xnn_gemm_config* gemm_config, size_t input_channels, size_t output_channels, size_t groups, - size_t k_stride, const void* accumulator_init, const void* weights, + size_t unused_block_size, size_t k_stride, const void* accumulator_init, const void* weights, xnn_init_scale_params_fn init_extra_data0_fn, const void* extra_data0, size_t extra_data0_element_size, xnn_init_scale_params_fn 
init_extra_data1_fn, const void* extra_data1, @@ -1715,7 +1716,7 @@ void xnn_pack_kai_qs8_weights_and_biases( const size_t n_stride = round_up(output_channels, nr); const size_t packed_weights_group_stride = n_stride * xnn_packed_stride_kai_qs8_weights_and_biases( - gemm_config, input_channels, k_stride, + gemm_config, input_channels, unused_block_size, k_stride, extra_data0_element_size + extra_data1_element_size); if (flags & XNN_FLAG_TRANSPOSE_WEIGHTS) { @@ -1762,7 +1763,7 @@ void xnn_pack_kai_qs8_weights_and_biases( } size_t xnn_packed_stride_kai_f32_weights_and_biases( - const struct xnn_gemm_config* gemm_config, size_t unused_block_size, size_t k, size_t unused_k_stride, + const struct xnn_gemm_config* gemm_config, size_t k, size_t unused_block_size, size_t unused_k_stride, size_t extra_bytes) { size_t ret_val = kai_get_rhs_packed_stride_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(k) / @@ -1772,7 +1773,7 @@ size_t xnn_packed_stride_kai_f32_weights_and_biases( void xnn_pack_kai_f16_weights_and_biases( uint32_t flags, const struct xnn_gemm_config* gemm_config, - size_t input_channels, size_t output_channels, size_t groups, + size_t input_channels, size_t output_channels, size_t groups, size_t unused_block_size, size_t k_stride, const void* accumulator_init, const void* weights, xnn_init_scale_params_fn init_extra_data0_fn, const void* extra_data0, size_t extra_data0_element_size, diff --git a/src/xnnpack/pack.h b/src/xnnpack/pack.h index a15a2759365..b1692a05248 100644 --- a/src/xnnpack/pack.h +++ b/src/xnnpack/pack.h @@ -387,7 +387,7 @@ XNN_INTERNAL void xnn_pack_qs8_weights_and_biases( size_t input_channels, // size_t output_channels, // size_t groups, // - size_t block_size, // + size_t unused_block_size, // size_t k_stride, // const void* accumulator_init, // const void* weights, // @@ -403,7 +403,7 @@ XNN_INTERNAL void xnn_pack_qs8_weights_and_biases( XNN_INTERNAL size_t xnn_packed_stride_qs8_weights_and_biases( const struct xnn_gemm_config* gemm_config, // size_t k, // - size_t block_size, // + size_t unused_block_size, // size_t k_stride, // size_t extra_bytes); @@ -414,7 +414,7 @@ XNN_INTERNAL void xnn_pack_qs4_weights_and_biases( size_t input_channels, // size_t output_channels, // size_t groups, // - size_t block_size, // + size_t unused_block_size, // size_t k_stride, // const void* accumulator_init, // const void* weights, // @@ -430,7 +430,7 @@ XNN_INTERNAL void xnn_pack_qs4_weights_and_biases( XNN_INTERNAL size_t xnn_packed_stride_qs4_weights_and_biases( const struct xnn_gemm_config* gemm_config, // size_t k, // - size_t block_size, // + size_t unused_block_size, // size_t k_stride, // size_t extra_bytes); @@ -466,7 +466,7 @@ XNN_INTERNAL void xnn_pack_qu8_weights_and_biases( size_t input_channels, // size_t output_channels, // size_t groups, // - size_t block_size, // + size_t unused_block_size, // size_t k_stride, // const void* accumulator_init, // const void* weights, // @@ -482,7 +482,7 @@ XNN_INTERNAL void xnn_pack_qu8_weights_and_biases( XNN_INTERNAL size_t xnn_packed_stride_qu8_weights_and_biases( const struct xnn_gemm_config* gemm_config, // size_t k, // - size_t block_size, // + size_t unused_block_size, // size_t k_stride, // size_t extra_bytes); @@ -493,7 +493,7 @@ XNN_INTERNAL void xnn_pack_kai_qs4_weights_and_biases( size_t input_channels, // size_t output_channels, // size_t groups, // - size_t block_size, // + size_t unused_block_size, // size_t k_stride, // const void* accumulator_init, // const void* weights, // @@ -509,7 +509,7 @@ XNN_INTERNAL void 
xnn_pack_kai_qs4_weights_and_biases( XNN_INTERNAL size_t xnn_packed_stride_kai_qs4_weights_and_biases( const struct xnn_gemm_config* gemm_config, // size_t k, // - size_t block_size, // + size_t unused_block_size, // size_t k_stride, // size_t extra_bytes); @@ -519,6 +519,7 @@ XNN_INTERNAL void xnn_pack_kai_qs8_weights_and_biases( size_t input_channels, // size_t output_channels, // size_t groups, // + size_t unused_block_size, // size_t k_stride, // const void* accumulator_init, // const void* weights, // @@ -534,35 +535,60 @@ XNN_INTERNAL void xnn_pack_kai_qs8_weights_and_biases( XNN_INTERNAL size_t xnn_packed_stride_kai_qs8_weights_and_biases( const struct xnn_gemm_config* gemm_config, // size_t k, // + size_t unused_block_size, // size_t k_stride, // size_t extra_bytes); size_t xnn_packed_stride_kai_f16_weights_and_biases( - const struct xnn_gemm_config* gemm_config, size_t k, size_t unused_k_stride, + const struct xnn_gemm_config* gemm_config, // + size_t k, // + size_t unused_block_size, // + size_t unused_k_stride, // size_t extra_bytes); size_t xnn_packed_stride_kai_f32_weights_and_biases( - const struct xnn_gemm_config* gemm_config, size_t k, size_t block_size, - size_t unused_k_stride, size_t extra_bytes); + const struct xnn_gemm_config* gemm_config, // + size_t k, // + size_t unused_block_size, // + size_t unused_k_stride, // + size_t extra_bytes); void xnn_pack_kai_f16_weights_and_biases( - uint32_t flags, const struct xnn_gemm_config* gemm_config, - size_t input_channels, size_t output_channels, size_t groups, - size_t k_stride, const void* accumulator_init, const void* weights, - xnn_init_scale_params_fn init_extra_data0_fn, const void* extra_data0, - size_t extra_data0_element_size, - xnn_init_scale_params_fn init_extra_data1_fn, const void* extra_data1, - size_t extra_data1_element_size, void* packed_weights_ptr, + uint32_t flags, // + const struct xnn_gemm_config* gemm_config, // + size_t input_channels, // + size_t output_channels, // + size_t groups, // + size_t unused_block_size, // + size_t k_stride, // + const void* accumulator_init, // + const void* weights, // + xnn_init_scale_params_fn init_extra_data0_fn, // + const void* extra_data0, // + size_t extra_data0_element_size, // + xnn_init_scale_params_fn init_extra_data1_fn, // + const void* extra_data1, // + size_t extra_data1_element_size, // + void* packed_weights_ptr, // const void* params); void xnn_pack_kai_f32_weights_and_biases( - uint32_t flags, const struct xnn_gemm_config* gemm_config, - size_t input_channels, size_t output_channels, size_t groups, - size_t block_size, size_t k_stride, const void* accumulator_init, const void* weights, - xnn_init_scale_params_fn init_extra_data0_fn, const void* extra_data0, - size_t extra_data0_element_size, - xnn_init_scale_params_fn init_extra_data1_fn, const void* extra_data1, - size_t extra_data1_element_size, void* packed_weights_ptr, + uint32_t flags, // + const struct xnn_gemm_config* gemm_config, // + size_t input_channels, // + size_t output_channels, // + size_t groups, // + size_t unused_block_size, // + size_t k_stride, // + const void* accumulator_init, // + const void* weights, // + xnn_init_scale_params_fn init_extra_data0_fn, // + const void* extra_data0, // + size_t extra_data0_element_size, // + xnn_init_scale_params_fn init_extra_data1_fn, // + const void* extra_data1, // + size_t extra_data1_element_size, // + void* packed_weights_ptr, // const void* params); XNN_INTERNAL void xnn_pack_kai_qb4_weights_and_biases( diff --git 
a/test/gemm-microkernel-tester.cc b/test/gemm-microkernel-tester.cc index 79d9f3c8b45..a3051c6bf83 100644 --- a/test/gemm-microkernel-tester.cc +++ b/test/gemm-microkernel-tester.cc @@ -1796,7 +1796,7 @@ void GemmMicrokernelTester::Test( gemm_config.log2_sr = static_cast(31 - math_clz_nonzero_u32(sr())); const size_t packed_w_stride = - packed_stride(&gemm_config, k2, /*k_stride=*/k2, /*extra_bytes=*/0); + packed_stride(&gemm_config, k2, /*unused_block_size=*/0, /*k_stride=*/k2, /*extra_bytes=*/0); const size_t packed_w_size = packed_w_stride * round_up(n(), nr()); xnnpack::Buffer packed_w(packed_w_size); @@ -1822,7 +1822,8 @@ void GemmMicrokernelTester::Test( params.input_zero_point = 1; params.kernel_zero_point = b_zero_point(); pack(/*flags=*/0, &gemm_config, k2, n(), - /*groups=*/1, /*k_stride=*/k2, + /*groups=*/1, /*unused_block_size=*/0, + /*k_stride=*/k2, /*accumulator_init=*/nullptr, /*weights=*/b.data(), /*int_extra_data0_fn=*/nullptr, @@ -1939,7 +1940,7 @@ void GemmMicrokernelTester::Test_QP8F32QC8W( gemm_config.log2_sr = static_cast(31 - math_clz_nonzero_u32(sr())); const size_t packed_w_stride = - packed_stride(&gemm_config, k(), /*k_stride=*/k(), /*extra_bytes=*/0); + packed_stride(&gemm_config, k(), /*unused_block_size=*/0, /*k_stride=*/k(), /*extra_bytes=*/0); const size_t packed_w_size = packed_w_stride * round_up(n(), nr()); xnnpack::Buffer packed_w(packed_w_size); @@ -1965,7 +1966,8 @@ void GemmMicrokernelTester::Test_QP8F32QC8W( params.input_zero_point = 1; params.scale_multiplier = 1.0f; pack(/*flags=*/0, &gemm_config, k(), n(), - /*groups=*/1, /*k_stride=*/k(), + /*groups=*/1, /*unused_block_size=*/0, + /*k_stride=*/k(), /*accumulator_init=*/nullptr, /*weights=*/b.data(), /*int_extra_data0_fn=*/nullptr, @@ -2085,7 +2087,7 @@ void GemmMicrokernelTester::Test( gemm_config.log2_sr = static_cast(31 - math_clz_nonzero_u32(sr())); const size_t packed_w_stride = - packed_stride(&gemm_config, k2, /*k_stride=*/bl(), /*extra_bytes=*/0); + packed_stride(&gemm_config, k2, /*block_size=*/bl(), /*k_stride=*/k2, /*extra_bytes=*/0); const size_t packed_w_size = packed_w_stride * round_up(n(), nr()); xnnpack::Buffer packed_w(packed_w_size); @@ -2113,7 +2115,8 @@ void GemmMicrokernelTester::Test( params.input_zero_point = 1; params.kernel_zero_point = b_zero_point(); pack(/*flags=*/0, &gemm_config, k2, n(), - /*groups=*/1, /*k_stride=*/bl(), + /*groups=*/1, /*block_size=*/bl(), + /*k_stride=*/k2, /*accumulator_init=*/nullptr, /*weights=*/b.data(), /*int_extra_data0_fn=*/nullptr,
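With all three patches applied, a single packing contract serves both per-channel and blockwise weights. The sketch below shows how a caller drives it; it is a sketch only. The helper name, the malloc-based allocation, the simplified k_stride, and the extra_data argument choices loosely mirror create_fully_connected_nc but are not code from the diffs, and it assumes a gemm_config whose qb4 entry points were registered by patch 1.

#include <stdint.h>
#include <stdlib.h>
#include "xnnpack/config-types.h"  // struct xnn_gemm_config
#include "xnnpack/math.h"          // round_up, round_up_po2
#include "xnnpack/pack.h"          // struct xnn_qs8_qc4w_packing_params

// Illustrative caller: pack blockwise-quantized (qb4) fully-connected
// weights through the unified entry points introduced by this series.
static void* pack_blockwise_fc_weights(
    const struct xnn_gemm_config* gemm_config, size_t input_channels,
    size_t output_channels, size_t block_size, const void* kernel,
    const float* bias, const uint16_t* bf16_kernel_scales,
    const struct xnn_qs8_qc4w_packing_params* packing_params) {
  // Simplified: the real operator also folds kr/sr padding into k_stride.
  const size_t k_stride = round_up_po2(input_channels, gemm_config->planes);
  const size_t n_stride = round_up(output_channels, gemm_config->nr);

  // One stride query now covers the int4 data, the bias slot, and the
  // per-block bf16 scales; block_size is threaded through instead of
  // selecting a dedicated blockwise ukernel.
  const size_t weights_stride = gemm_config->packed_stride_weights_and_biases(
      gemm_config, input_channels, block_size, k_stride, /*extra_bytes=*/0);

  void* packed = malloc(n_stride * weights_stride);
  if (packed == NULL) return NULL;

  // The qb4 packer writes the blockwise scales (taken from extra_data1) and
  // the optional bias itself, so no post-packing fixup remains at the call
  // site.
  gemm_config->pack_weights_and_biases(
      /*flags=*/0, gemm_config, input_channels, output_channels,
      /*groups=*/1, block_size, k_stride,
      /*accumulator_init=*/bias,
      /*weights=*/kernel,
      /*init_extra_data0_fn=*/NULL, /*extra_data0=*/NULL,
      /*extra_data0_element_size=*/0,
      /*init_extra_data1_fn=*/NULL, /*extra_data1=*/bf16_kernel_scales,
      /*extra_data1_element_size=*/0,
      /*packed_weights_ptr=*/packed, packing_params);
  return packed;
}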