diff --git a/bench/gemm-benchmark.cc b/bench/gemm-benchmark.cc
index d7eb7de4383..d8b152dc5de 100644
--- a/bench/gemm-benchmark.cc
+++ b/bench/gemm-benchmark.cc
@@ -735,7 +735,7 @@ void GEMMBenchmark(benchmark::State& state,
   gemm_config.log2_sr = static_cast<uint8_t>(31 - math_clz_nonzero_u32(sr));
 
   const size_t packed_w_stride =
-      packed_stride(&gemm_config, kc, /*k_stride=*/kc, /*extra_bytes=*/0);
+      packed_stride(&gemm_config, kc, /*unused_block_size=*/0, /*k_stride=*/kc, /*extra_bytes=*/0);
   const size_t packed_w_size = packed_w_stride * round_up(nc, nr);
 
   const size_t c_elements = mc * nc;
@@ -760,7 +760,7 @@ void GEMMBenchmark(benchmark::State& state,
   const xnn_qs8_qc4w_packing_params packing_params = {/*input_zero_point=*/1,
                                                       /*kernel_zero_point=*/8};
   pack_weights(/*flags=*/0, &gemm_config, kc, nc,
-               /*groups=*/1, /*k_stride=*/kc,
+               /*groups=*/1, /*unused_block_size=*/0, /*k_stride=*/kc,
                /*accumulator_init=*/nullptr,
                /*weights=*/k.data(),
                /*int_extra_data0_fn=*/nullptr,
@@ -852,7 +852,7 @@ void GEMMBenchmark(benchmark::State& state,
   gemm_config.log2_sr = static_cast<uint8_t>(31 - math_clz_nonzero_u32(sr));
 
   const size_t packed_w_stride =
-      packed_stride(&gemm_config, k2, /*k_stride=*/bl, /*extra_bytes=*/0);
+      packed_stride(&gemm_config, k2, /*block_size=*/bl, /*k_stride=*/kc, /*extra_bytes=*/0);
   const size_t packed_w_size = packed_w_stride * round_up(nc, nr);
 
   const size_t c_elements = mc * nc;
@@ -879,7 +879,8 @@ void GEMMBenchmark(benchmark::State& state,
   const xnn_qs8_qc4w_packing_params packing_params = {/*input_zero_point=*/1,
                                                       /*kernel_zero_point=*/8};
   pack_weights(/*flags=*/0, &gemm_config, k2, nc,
-               /*groups=*/1, /*k_stride=*/bl,
+               /*groups=*/1, /*block_size=*/bl,
+               /*k_stride=*/kc,
                /*accumulator_init=*/nullptr,
                /*weights=*/k.data(),
                /*int_extra_data0_fn=*/nullptr,
diff --git a/src/configs/gemm-config.c b/src/configs/gemm-config.c
index 78bd447a9b2..05be86767dd 100644
--- a/src/configs/gemm-config.c
+++ b/src/configs/gemm-config.c
@@ -1539,7 +1539,8 @@ static void init_qd8_f16_qc4w_gemm_config(void) {
 }
 
 static void init_qd8_f16_qb4w_gemm_config(void) {
-  qd8_f16_qb4w_gemm_config.pack_gemm_goi_bl = (xnn_packw_gemm_goi_bl_ukernel_fn) xnn_pack_qs8_qb4w_gemm_goi_w;
+  qd8_f16_qb4w_gemm_config.packed_stride_weights_and_biases = xnn_packed_stride_qb4_weights_and_biases;
+  qd8_f16_qb4w_gemm_config.pack_weights_and_biases = xnn_pack_qb4_weights_and_biases;
   #if XNN_ARCH_ARM && XNN_ENABLE_ARM_FP16_VECTOR && XNN_ENABLE_ARM_FP16_SCALAR
     const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config();
@@ -1845,7 +1846,8 @@ static void init_qp8_f32_qb4w_gemm_config(void) {
 }
 
 static void init_qdu8_f32_qb4w_gemm_config(void) {
-  qdu8_f32_qb4w_gemm_config.pack_gemm_goi_bl = (xnn_packw_gemm_goi_bl_ukernel_fn) xnn_pack_qs8_qb4w_gemm_goi_w;
+  qdu8_f32_qb4w_gemm_config.packed_stride_weights_and_biases = xnn_packed_stride_qb4_weights_and_biases;
+  qdu8_f32_qb4w_gemm_config.pack_weights_and_biases = xnn_pack_qb4_weights_and_biases;
   #if XNN_ARCH_X86 || XNN_ARCH_X86_64
     const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config();
     assert(hardware_config != NULL);
@@ -1884,7 +1886,8 @@ static void init_qdu8_f32_qb4w_gemm_config(void) {
 }
 
 static void init_qd8_f32_qb4w_gemm_config(void) {
-  qd8_f32_qb4w_gemm_config.pack_gemm_goi_bl = (xnn_packw_gemm_goi_bl_ukernel_fn) xnn_pack_qs8_qb4w_gemm_goi_w;
+  qd8_f32_qb4w_gemm_config.packed_stride_weights_and_biases = xnn_packed_stride_qb4_weights_and_biases;
+  qd8_f32_qb4w_gemm_config.pack_weights_and_biases = xnn_pack_qb4_weights_and_biases;
  #if XNN_ARCH_ARM
    const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config();
diff --git a/src/operators/batch-matrix-multiply-nc.c b/src/operators/batch-matrix-multiply-nc.c
index c19825035a6..f85338dcbe2 100644
--- a/src/operators/batch-matrix-multiply-nc.c
+++ b/src/operators/batch-matrix-multiply-nc.c
@@ -182,7 +182,9 @@ enum xnn_status xnn_create_batch_matrix_multiply_nc_f32_const_weights(
   // Pack the weights.
   if (gemm_config->pack_weights_and_biases) {
     gemm_config->pack_weights_and_biases(flags, gemm_config, k, n,
-                                         /*groups=*/batch_size_b, k_stride,
+                                         /*groups=*/batch_size_b,
+                                         /*unused_block_size=*/0,
+                                         /*k_stride=*/k_stride,
                                          /*accumulator_init=*/NULL,
                                          /*weights=*/data_b,
                                          /*int_extra_data0_fn=*/NULL,
@@ -311,7 +313,7 @@ enum xnn_status create_batch_matrix_multiply_nc_qx8_f32_qc8w(
   const size_t weights_stride =
       gemm_config->packed_stride_weights_and_biases
          ? gemm_config->packed_stride_weights_and_biases(
-                gemm_config, k, k_stride, extra_bytes)
+                gemm_config, k, /*unused_block_size=*/0, k_stride, extra_bytes)
          : (k_stride << XNN_LOG2_SIZEOF_INT8_T) + extra_bytes + sizeof(int32_t);
   batch_matrix_multiply_op->weights_stride = weights_stride;
@@ -347,7 +349,9 @@ enum xnn_status create_batch_matrix_multiply_nc_qx8_f32_qc8w(
         batch_matrix_multiply_op->flags ^ XNN_FLAG_TRANSPOSE_WEIGHTS,
         gemm_config, /*input_channels=*/k, /*output_channels=*/n,
-        /*groups=*/batch_size_b, k_stride,
+        /*groups=*/batch_size_b,
+        /*unused_block_size=*/0,
+        /*k_stride=*/k_stride,
         /*accumulator_init=*/NULL,
         /*weights=*/data_b,
         /*int_extra_data0_fn=*/
diff --git a/src/operators/convolution-nhwc.c b/src/operators/convolution-nhwc.c
index bfc0be24eb7..3f3a22b3720 100644
--- a/src/operators/convolution-nhwc.c
+++ b/src/operators/convolution-nhwc.c
@@ -372,6 +372,7 @@ static enum xnn_status create_gemm_or_igemm(
         gemm_config->pack_weights_and_biases(
             flags, gemm_config, group_input_channels, group_output_channels,
             groups,
+            /*unused_block_size=*/0,
             k_stride,
             /*accumulator_init=*/bias,
             /*weights=*/kernel,
diff --git a/src/operators/fully-connected-nc.c b/src/operators/fully-connected-nc.c
index 528d58a31f9..6d4bf5fc960 100644
--- a/src/operators/fully-connected-nc.c
+++ b/src/operators/fully-connected-nc.c
@@ -44,7 +44,6 @@ static enum xnn_status create_fully_connected_nc(
     const void* bias,
     uint32_t flags,
     size_t block_size,
-    size_t extra_bl_bytes,
     const uint16_t* blockwise_kernel_scale_params,
     uint32_t log2_input_element_size,
     uint32_t log2_filter_element_size,
@@ -52,7 +51,6 @@ static enum xnn_status create_fully_connected_nc(
     uint32_t bias_element_size,
     xnn_packw_gemm_gio_ukernel_fn pack_gemm_gio_w,
     xnn_packw_gemm_goi_ukernel_fn pack_gemm_goi_w,
-    xnn_packw_gemm_goi_bl_ukernel_fn pack_gemm_goi_bl_w,
     const void* packing_params,
     int packed_weights_padding_byte,
     size_t extra_weights_bytes,
@@ -155,7 +153,7 @@ static enum xnn_status create_fully_connected_nc(
   const size_t weights_stride =
       gemm_config->packed_stride_weights_and_biases
          ? gemm_config->packed_stride_weights_and_biases(
-                gemm_config, input_channels, block_wise ? block_size : k_stride, extra_weights_bytes)
+                gemm_config, input_channels, block_size, k_stride, extra_weights_bytes)
          : (k_stride << log2_filter_element_size) + bias_element_size +
                 extra_weights_bytes + block_scale_bytes;
   const size_t packed_weights_size = n_stride * weights_stride;
@@ -192,7 +190,8 @@ static enum xnn_status create_fully_connected_nc(
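+    // Blockwise (qb4w) and channelwise kernels share this packing entry point;
+    // a block_size of zero selects the channelwise layout.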
     gemm_config->pack_weights_and_biases(
         flags, gemm_config, input_channels, output_channels,
         /*groups=*/1,
-        block_wise ? block_size : k_stride,
+        /*block_size=*/block_size,
+        /*k_stride=*/k_stride,
         /*accumulator_init=*/bias,
         /*weights=*/kernel,
         /*int_extra_data0_fn=*/(xnn_init_scale_params_fn)init_scale_params,
@@ -204,16 +203,6 @@ static enum xnn_status create_fully_connected_nc(
         /*extra_data1_size=*/init_kernel_scale_params != NULL ? sizeof(float) : 0,
         /*packed_weights_ptr=*/weights_ptr, packing_params);
-
-    if (block_wise && bias != NULL) {
-      void* weights_start = (void*) ((uintptr_t) weights_ptr +
-          gemm_config->nr * (sizeof(float) + (block_size * sizeof(int8_t) / 2)));
-      weights_start = (void*) ((uintptr_t) weights_ptr + gemm_config->nr * (weights_stride - sizeof(float))) ;
-      xnn_init_qs8_qc8w_scale_fp32_params(
-          output_channels, gemm_config->nr, gemm_config->nr,
-          gemm_config->nr * weights_stride, gemm_config->nr * weights_stride, 0,
-          bias, weights_start);
-    }
   } else {
     if (flags & XNN_FLAG_TRANSPOSE_WEIGHTS) {
       pack_gemm_gio_w(
@@ -225,24 +214,13 @@ static enum xnn_status create_fully_connected_nc(
           gemm_config->nr * extra_weights_bytes, packing_params);
     } else {
-      if (block_wise) {
-        pack_gemm_goi_bl_w(
-            /*groups=*/1, output_channels, input_channels,
-            nr, kr, sr, block_size,
-            kernel, /*bias=*/NULL, /*scale=*/blockwise_kernel_scale_params,
-            weights_ptr,
-            gemm_config->nr * extra_bl_bytes,
-            gemm_config->nr * extra_weights_bytes,
-            packing_params);
-      } else {
-        pack_gemm_goi_w(
-            /*groups=*/1, output_channels, input_channels,
-            nr, kr, sr,
-            kernel, bias, /*scale=*/NULL,
-            weights_ptr,
-            gemm_config->nr * extra_weights_bytes,
-            packing_params);
-      }
+      pack_gemm_goi_w(
+          /*groups=*/1, output_channels, input_channels,
+          nr, kr, sr,
+          kernel, bias, /*scale=*/NULL,
+          weights_ptr,
+          gemm_config->nr * extra_weights_bytes,
+          packing_params);
     }
     if (kernel_scale_params != NULL) {
       assert(init_kernel_scale_params != NULL);
@@ -267,32 +245,6 @@ static enum xnn_status create_fully_connected_nc(
           gemm_config->nr * weights_stride, gemm_config->nr * weights_stride, 0,
           scale_params, weights);
     }
-
-    if (block_wise) {
-      // Fill in kernel scale.
-      void* weights_start = (void*) ((uintptr_t) weights_ptr +
-          gemm_config->nr * (sizeof(float) + (block_size * sizeof(int8_t) / 2)));
-
-      const size_t block_stride = /*weights*/block_size / 2 + sizeof(uint16_t);
-
-      xnn_init_blockwise_scale_bf16_params(
-          output_channels, gemm_config->nr, gemm_config->nr,
-          gemm_config->nr * weights_stride,
-          gemm_config->nr * weights_stride,
-          /*num_blocks=*/num_blocks,
-          /*block_stride=*/gemm_config->nr * block_stride,
-          0,
-          (const xnn_bfloat16*)blockwise_kernel_scale_params, weights_start);
-
-      // Fill in bias.
-      if (bias != NULL) {
-        weights_start = (void*) ((uintptr_t) weights_ptr + gemm_config->nr * (weights_stride - sizeof(float))) ;
-        xnn_init_qs8_qc8w_scale_fp32_params(
-            output_channels, gemm_config->nr, gemm_config->nr,
-            gemm_config->nr * weights_stride, gemm_config->nr * weights_stride, 0,
-            bias, weights_start);
-      }
-    }
   }
 
   if (use_weights_cache(fully_connected_op)) {
@@ -397,7 +349,6 @@ enum xnn_status create_fully_connected_nc_f16(
      input_stride, output_stride, kernel, bias, flags,
      /*block_size=*/0,
-      /*extra_bl_bytes=*/0,
      /*blockwise_kernel_scale_params=*/NULL,
      /*log2_input_element_size=*/XNN_LOG2_SIZEOF_HALF,
      /*log2_filter_element_size=*/XNN_LOG2_SIZEOF_HALF,
@@ -405,7 +356,6 @@ enum xnn_status create_fully_connected_nc_f16(
      /*bias_element_size=*/sizeof(uint16_t),
      pack_gemm_gio_w,
      pack_gemm_goi_w,
-      /*pack_gemm_goi_bl_w=*/NULL,
      /*packing_params=*/NULL,
      /*packed_weights_padding_byte=*/0,
      /*extra_weights_bytes=*/0,
@@ -533,7 +483,6 @@ enum xnn_status create_fully_connected_nc_qx8_f16_qc4w(
      input_stride, output_stride, kernel, /*bias=*/NULL, flags,
      /*block_size=*/0,
-      /*extra_bl_bytes=*/0,
      /*blockwise_kernel_scale_params=*/NULL,
      /*log2_input_element_size=*/XNN_LOG2_SIZEOF_INT8_T,
      /*log2_filter_element_size=*/XNN_LOG2_SIZEOF_UINT8_T,
@@ -541,7 +490,6 @@ enum xnn_status create_fully_connected_nc_qx8_f16_qc4w(
      /*bias_element_size=*/sizeof(float),
      (xnn_packw_gemm_gio_ukernel_fn) gemm_config->pack_gemm_gio,
      (xnn_packw_gemm_goi_ukernel_fn) gemm_config->pack_gemm_goi,
-      /*pack_gemm_goi_bl_w=*/NULL,
      &packing_params,
      /*packed_weights_padding_byte=*/0,
      /*extra_weights_bytes=*/sizeof(float) * 2,
@@ -708,7 +656,6 @@ enum xnn_status xnn_create_fully_connected_nc_qd8_f16_qb4w(
      input_stride, output_stride, kernel, bias, flags,
      /*block_size=*/block_size,
-      /*extra_bl_bytes=*/sizeof(uint16_t),
      /*blockwise_kernel_scale_params=*/kernel_scale,
      /*log2_input_element_size=*/XNN_LOG2_SIZEOF_INT8_T,
      /*log2_filter_element_size=*/XNN_LOG2_SIZEOF_UINT8_T,
@@ -716,7 +663,6 @@ enum xnn_status xnn_create_fully_connected_nc_qd8_f16_qb4w(
      /*bias_element_size=*/sizeof(float),
      /*pack_gemm_gio_w,=*/ NULL,
      /*pack_gemm_goi_w=*/ NULL,
-      /*pack_gemm_goi_bl_w=*/gemm_config->pack_gemm_goi_bl,
      /*packing_params=*/&packing_params,
      /*packed_weights_padding_byte=*/0,
      /*extra_weights_bytes=*/sizeof(float),
@@ -801,7 +747,6 @@ enum xnn_status create_fully_connected_nc_qx8_f32_qc4w(
      input_stride, output_stride, kernel, /*bias=*/NULL, flags,
      /*block_size=*/0,
-      /*extra_bl_bytes=*/0,
      /*blockwise_kernel_scale_params=*/NULL,
      /*log2_input_element_size=*/XNN_LOG2_SIZEOF_INT8_T,
      /*log2_filter_element_size=*/XNN_LOG2_SIZEOF_UINT8_T,
@@ -809,7 +754,6 @@ enum xnn_status create_fully_connected_nc_qx8_f32_qc4w(
      /*bias_element_size=*/sizeof(float),
      (xnn_packw_gemm_gio_ukernel_fn) gemm_config->pack_gemm_gio,
      (xnn_packw_gemm_goi_ukernel_fn) gemm_config->pack_gemm_goi,
-      /*pack_gemm_goi_bl_w=*/NULL,
      &packing_params,
      /*packed_weights_padding_byte=*/0,
      /*extra_weights_bytes=*/sizeof(float) * 2,
@@ -921,7 +865,6 @@ static enum xnn_status xnn_create_fully_connected_nc_qp8_f32_qcxw(
      input_channels, output_channels, input_stride, output_stride, kernel,
      /*bias=*/NULL, flags,
      /*block_size=*/0,
-      /*extra_bl_bytes=*/0,
      /*blockwise_kernel_scale_params=*/NULL,
      /*log2_input_element_size=*/XNN_LOG2_SIZEOF_INT8_T,
      /*log2_filter_element_size=*/XNN_LOG2_SIZEOF_UINT8_T,
@@ -929,7 +872,7 @@ static enum xnn_status xnn_create_fully_connected_nc_qp8_f32_qcxw(
      /*bias_element_size=*/sizeof(float),
      (xnn_packw_gemm_gio_ukernel_fn)gemm_config->pack_gemm_gio,
      (xnn_packw_gemm_goi_ukernel_fn)gemm_config->pack_gemm_goi,
-      /*pack_gemm_goi_bl_w=*/NULL, packing_params,
+      packing_params,
      /*packed_weights_padding_byte=*/0,
      /*extra_weights_bytes=*/0,
@@ -1119,7 +1062,6 @@ enum xnn_status xnn_create_fully_connected_nc_qp8_f32_qb4w(
      input_stride, output_stride, kernel, bias, flags,
      /*block_size=*/block_size,
-      /*extra_bl_bytes=*/sizeof(uint16_t),
      /*blockwise_kernel_scale_params=*/kernel_scale,
      /*log2_input_element_size=*/XNN_LOG2_SIZEOF_INT8_T,
      /*log2_filter_element_size=*/XNN_LOG2_SIZEOF_UINT8_T,
@@ -1127,7 +1069,6 @@ enum xnn_status xnn_create_fully_connected_nc_qp8_f32_qb4w(
      /*bias_element_size=*/sizeof(float),
      /*pack_gemm_gio_w,=*/ NULL,
      /*pack_gemm_goi_w=*/ NULL,
-      /*pack_gemm_goi_bl_w=*/gemm_config->pack_gemm_goi_bl,
      &packing_params,
      /*packed_weights_padding_byte=*/0,
      /*extra_weights_bytes=*/0,
@@ -1244,7 +1185,6 @@ enum xnn_status create_fully_connected_nc_qx8_f32_qb4w(
      input_stride, output_stride, kernel, bias, flags,
      /*block_size=*/block_size,
-      /*extra_bl_bytes=*/sizeof(uint16_t),
      /*blockwise_kernel_scale_params=*/kernel_scale,
      /*log2_input_element_size=*/XNN_LOG2_SIZEOF_INT8_T,
      /*log2_filter_element_size=*/XNN_LOG2_SIZEOF_UINT8_T,
@@ -1252,7 +1192,6 @@ enum xnn_status create_fully_connected_nc_qx8_f32_qb4w(
      /*bias_element_size=*/sizeof(float),
      /*pack_gemm_gio_w,=*/ NULL,
      /*pack_gemm_goi_w=*/ NULL,
-      /*pack_gemm_goi_bl_w=*/gemm_config->pack_gemm_goi_bl,
      &packing_params,
      /*packed_weights_padding_byte=*/0,
      /*extra_weights_bytes=*/sizeof(float),
@@ -1375,7 +1314,6 @@ enum xnn_status create_fully_connected_nc_qdx8_f32_qc8w(
      input_stride, output_stride, kernel, NULL, flags,
      /*block_size=*/0,
-      /*extra_bl_bytes=*/0,
      /*blockwise_kernel_scale_params=*/NULL,
      /*log2_input_element_size=*/XNN_LOG2_SIZEOF_INT8_T,
      /*log2_filter_element_size=*/XNN_LOG2_SIZEOF_INT8_T,
@@ -1383,7 +1321,6 @@ enum xnn_status create_fully_connected_nc_qdx8_f32_qc8w(
      /*bias_element_size=*/sizeof(float),
      (xnn_packw_gemm_gio_ukernel_fn) gemm_config->pack_gemm_gio,
      (xnn_packw_gemm_goi_ukernel_fn) gemm_config->pack_gemm_goi,
-      /*pack_gemm_goi_bl_w=*/NULL,
      &packing_params,
      /*packed_weights_padding_byte=*/0,
      /*extra_weights_bytes=*/sizeof(float) * 2,
@@ -1506,7 +1443,6 @@ enum xnn_status create_fully_connected_nc_qx8_f16_qc8w(
      input_stride, output_stride, kernel, NULL, flags,
      /*block_size=*/0,
-      /*extra_bl_bytes=*/0,
      /*blockwise_kernel_scale_params=*/NULL,
      /*log2_input_element_size=*/XNN_LOG2_SIZEOF_INT8_T,
      /*log2_filter_element_size=*/XNN_LOG2_SIZEOF_INT8_T,
@@ -1514,7 +1450,6 @@ enum xnn_status create_fully_connected_nc_qx8_f16_qc8w(
      /*bias_element_size=*/sizeof(float),
      (xnn_packw_gemm_gio_ukernel_fn) gemm_config->pack_gemm_gio,
      (xnn_packw_gemm_goi_ukernel_fn) gemm_config->pack_gemm_goi,
-      /*pack_gemm_goi_bl_w=*/NULL,
      &packing_params,
      /*packed_weights_padding_byte=*/0,
      /*extra_weights_bytes=*/sizeof(float) * 2,
@@ -1662,7 +1597,6 @@ enum xnn_status create_fully_connected_nc_f32(
      input_stride, output_stride, kernel, bias, flags,
      /*block_size=*/0,
-      /*extra_bl_bytes=*/0,
      /*blockwise_kernel_scale_params=*/NULL,
      /*log2_input_element_size=*/XNN_LOG2_SIZEOF_FLOAT,
      /*log2_filter_element_size=*/XNN_LOG2_SIZEOF_FLOAT,
@@ -1670,7 +1604,6 @@ enum xnn_status create_fully_connected_nc_f32(
      /*bias_element_size=*/sizeof(float),
      (xnn_packw_gemm_gio_ukernel_fn) gemm_config->pack_gemm_gio,
      (xnn_packw_gemm_goi_ukernel_fn) gemm_config->pack_gemm_goi,
-      /*pack_gemm_goi_bl_w=*/NULL,
      /*packing_params=*/NULL,
      /*packed_weights_padding_byte=*/0,
      /*extra_weights_bytes=*/0,
@@ -1807,7 +1740,6 @@ enum xnn_status xnn_create_fully_connected_nc_f32_qc4w(
      input_stride, output_stride, kernel, bias, flags,
      /*block_size=*/0,
-      /*extra_bl_bytes=*/0,
      /*blockwise_kernel_scale_params=*/NULL,
      /*log2_input_element_size=*/XNN_LOG2_SIZEOF_FLOAT,
      // Pass 1 byte even though it is half byte, we handle the division via filter_is_nibble == true.
@@ -1816,7 +1748,6 @@ enum xnn_status xnn_create_fully_connected_nc_f32_qc4w(
      /*bias_element_size=*/sizeof(float),
      (xnn_packw_gemm_gio_ukernel_fn) NULL,
      (xnn_packw_gemm_goi_ukernel_fn) gemm_config->pack_gemm_goi,
-      /*pack_gemm_goi_bl_w=*/NULL,
      /*packing_params=*/NULL,
      /*packed_weights_padding_byte=*/0,
      /*extra_weights_bytes=*/sizeof(float),
@@ -1900,7 +1831,6 @@ enum xnn_status xnn_create_fully_connected_nc_f32_qc8w(
      input_stride, output_stride, kernel, bias, flags,
      /*block_size=*/0,
-      /*extra_bl_bytes=*/0,
      /*blockwise_kernel_scale_params=*/NULL,
      /*log2_input_element_size=*/XNN_LOG2_SIZEOF_FLOAT,
      /*log2_filter_element_size=*/XNN_LOG2_SIZEOF_INT8_T,
@@ -1908,7 +1838,6 @@ enum xnn_status xnn_create_fully_connected_nc_f32_qc8w(
      /*bias_element_size=*/sizeof(float),
      (xnn_packw_gemm_gio_ukernel_fn) gemm_config->pack_gemm_gio,
      (xnn_packw_gemm_goi_ukernel_fn) gemm_config->pack_gemm_goi,
-      /*pack_gemm_goi_bl_w=*/NULL,
      /*packing_params=*/NULL,
      /*packed_weights_padding_byte=*/0,
      /*extra_weights_bytes=*/sizeof(float),
@@ -2000,7 +1929,6 @@ enum xnn_status xnn_create_fully_connected_nc_qs8(
      input_stride, output_stride, kernel, bias, flags,
      /*block_size=*/0,
-      /*extra_bl_bytes=*/0,
      /*blockwise_kernel_scale_params=*/NULL,
      /*log2_input_element_size=*/XNN_LOG2_SIZEOF_INT8_T,
      /*log2_filter_element_size=*/XNN_LOG2_SIZEOF_INT8_T,
@@ -2008,7 +1936,6 @@ enum xnn_status xnn_create_fully_connected_nc_qs8(
      /*bias_element_size=*/sizeof(int32_t),
      (xnn_packw_gemm_gio_ukernel_fn) gemm_config->pack_gemm_gio,
      (xnn_packw_gemm_goi_ukernel_fn) gemm_config->pack_gemm_goi,
-      /*pack_gemm_goi_bl_w=*/NULL,
      &packing_params,
      /*packed_weights_padding_byte=*/0,
      /*extra_weights_bytes=*/sizeof(float),
@@ -2107,7 +2034,6 @@ enum xnn_status xnn_create_fully_connected_nc_qs8_qc8w(
      input_stride, output_stride, kernel, bias, flags,
      /*block_size=*/0,
-      /*extra_bl_bytes=*/0,
      /*blockwise_kernel_scale_params=*/NULL,
      /*log2_input_element_size=*/XNN_LOG2_SIZEOF_INT8_T,
      /*log2_filter_element_size=*/XNN_LOG2_SIZEOF_INT8_T,
@@ -2115,7 +2041,6 @@ enum xnn_status xnn_create_fully_connected_nc_qs8_qc8w(
      /*bias_element_size=*/sizeof(int32_t),
      (xnn_packw_gemm_gio_ukernel_fn) gemm_config->pack_gemm_gio,
      (xnn_packw_gemm_goi_ukernel_fn) gemm_config->pack_gemm_goi,
-      /*pack_gemm_goi_bl_w=*/NULL,
      &packing_params,
      /*packed_weights_padding_byte=*/0,
      /*extra_weights_bytes=*/sizeof(float),
@@ -2207,7 +2132,6 @@ enum xnn_status xnn_create_fully_connected_nc_qu8(
      input_stride, output_stride, kernel, bias, flags,
      /*block_size=*/0,
-      /*extra_bl_bytes=*/0,
      /*blockwise_kernel_scale_params=*/NULL,
      /*log2_input_element_size=*/XNN_LOG2_SIZEOF_UINT8_T,
      /*log2_filter_element_size=*/XNN_LOG2_SIZEOF_UINT8_T,
@@ -2215,7 +2139,6 @@ enum xnn_status xnn_create_fully_connected_nc_qu8(
      /*bias_element_size=*/sizeof(int32_t),
      (xnn_packw_gemm_gio_ukernel_fn) gemm_config->pack_gemm_gio,
      (xnn_packw_gemm_goi_ukernel_fn) gemm_config->pack_gemm_goi,
-      /*pack_gemm_goi_bl_w=*/NULL,
      &packing_params,
      /*packed_weights_padding_byte=*/kernel_zero_point,
      /*extra_weights_bytes=*/0,
diff --git a/src/reference/packing.cc b/src/reference/packing.cc
index 60f64b94114..002beee3556 100644
--- a/src/reference/packing.cc
+++ b/src/reference/packing.cc
@@ -19,6 +19,7 @@
 #include "xnnpack/math.h"
 #include "xnnpack/microfnptr.h"
 #include "xnnpack/microparams.h"
+#include "xnnpack/microparams-init.h"
 #include "xnnpack/pack.h"
 #include "xnnpack/unaligned.h"
@@ -1349,6 +1350,7 @@ void pack_weights_and_biases(uint32_t flags,                          //
                              size_t input_channels,                   //
                              size_t output_channels,                  //
                              size_t groups,                           //
+                             size_t unused_block_size,                //
                              size_t weights_stride,                   //
                              xnn_packw_gemm_gio_ukernel_fn pack_gemm_gio_w,  //
                              xnn_packw_gemm_goi_ukernel_fn pack_gemm_goi_w,  //
@@ -1413,8 +1415,8 @@ void pack_weights_and_biases(uint32_t flags,  //
 }
 
 size_t xnn_packed_stride_qs8_weights_and_biases(
-    const struct xnn_gemm_config* gemm_config, size_t unused_k, size_t k_stride,
-    size_t extra_bytes) {
+    const struct xnn_gemm_config* gemm_config, size_t unused_k, size_t unused_block_size,
+    size_t k_stride, size_t extra_bytes) {
   const size_t bias_element_size = sizeof(int32_t);
   const size_t log2_filter_element_size = XNN_LOG2_SIZEOF_INT8_T;
   return (k_stride << log2_filter_element_size) + bias_element_size +
@@ -1423,7 +1425,7 @@ size_t xnn_packed_stride_qs8_weights_and_biases(
 
 void xnn_pack_qs8_weights_and_biases(
     uint32_t flags, const struct xnn_gemm_config* gemm_config,
-    size_t input_channels, size_t output_channels, size_t groups,
+    size_t input_channels, size_t output_channels, size_t groups, size_t unused_block_size,
     size_t k_stride, const void* accumulator_init, const void* weights,
     xnn_init_scale_params_fn init_extra_data0_fn, const void* extra_data0,
     size_t extra_data0_element_size,
@@ -1433,9 +1435,9 @@ void xnn_pack_qs8_weights_and_biases(
   const size_t extra_bytes =
      extra_data0_element_size + extra_data1_element_size;
   const size_t weights_stride = xnn_packed_stride_qs8_weights_and_biases(
-      gemm_config, input_channels, k_stride, extra_bytes);
+      gemm_config, input_channels, unused_block_size, k_stride, extra_bytes);
   return pack_weights_and_biases(
-      flags, gemm_config, input_channels, output_channels, groups,
+      flags, gemm_config, input_channels, output_channels, groups, unused_block_size,
      weights_stride, (xnn_packw_gemm_gio_ukernel_fn)xnn_pack_qs8_gemm_gio_w,
      (xnn_packw_gemm_goi_ukernel_fn)xnn_pack_qs8_gemm_goi_w, accumulator_init,
      weights, init_extra_data0_fn, extra_data0, extra_data0_element_size,
@@ -1444,8 +1446,8 @@ void xnn_pack_qs8_weights_and_biases(
 }
 
 size_t xnn_packed_stride_qs4_weights_and_biases(
-    const struct xnn_gemm_config* gemm_config, size_t unused_k, size_t k_stride,
-    size_t extra_bytes) {
+    const struct xnn_gemm_config* gemm_config, size_t unused_k, size_t unused_block_size,
+    size_t k_stride, size_t extra_bytes) {
   const size_t bias_element_size = sizeof(int32_t);
   const size_t log2_filter_element_size = XNN_LOG2_SIZEOF_INT8_T;
   return (k_stride << log2_filter_element_size) + bias_element_size +
@@ -1455,7 +1457,7 @@ size_t xnn_packed_stride_qs4_weights_and_biases(
 
 void xnn_pack_qs4_weights_and_biases(
     uint32_t flags, const struct xnn_gemm_config* gemm_config,
     size_t input_channels, size_t output_channels, size_t groups,
-    size_t k_stride, const void* accumulator_init, const void* weights,
+    size_t unused_block_size, size_t k_stride, const void* accumulator_init, const void* weights,
     xnn_init_scale_params_fn init_extra_data0_fn, const void* extra_data0,
     size_t extra_data0_element_size,
     xnn_init_scale_params_fn init_extra_data1_fn, const void* extra_data1,
@@ -1464,10 +1466,10 @@ void xnn_pack_qs4_weights_and_biases(
   const size_t extra_bytes =
      extra_data0_element_size + extra_data1_element_size;
   const size_t weights_stride =
      xnn_packed_stride_qs8_weights_and_biases(
-          gemm_config, input_channels, k_stride, extra_bytes);
+          gemm_config, input_channels, unused_block_size, k_stride, extra_bytes);
   return pack_weights_and_biases(
      flags, gemm_config, input_channels, output_channels, groups,
-      weights_stride,
+      unused_block_size, weights_stride,
      (xnn_packw_gemm_gio_ukernel_fn)xnn_pack_qs8_qc4w_gemm_gio_w,
      (xnn_packw_gemm_goi_ukernel_fn)xnn_pack_qs8_qc4w_gemm_goi_w,
      accumulator_init, weights, init_extra_data0_fn, extra_data0,
@@ -1475,9 +1477,111 @@ void xnn_pack_qs4_weights_and_biases(
      extra_data1_element_size, packed_weights_ptr, extra_bytes, params);
 }
 
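+// Bytes per packed output channel for qb4w: k_stride bytes of 4-bit weights,
+// an int32 accumulator, one bf16 scale per block of block_size input
+// channels, plus any caller-requested extra bytes.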
+size_t xnn_packed_stride_qb4_weights_and_biases(
+    const struct xnn_gemm_config* gemm_config, size_t k, size_t block_size,
+    size_t k_stride, size_t extra_bytes) {
+  const size_t planes = gemm_config->planes;
+
+  size_t input_channels = round_up_po2(k, planes);
+
+  size_t block_scale_bytes = 0;
+  size_t num_blocks = 0;
+  const bool block_wise = (block_size != 0);
+  if (block_wise) {
+    num_blocks = input_channels / block_size;
+    block_scale_bytes += num_blocks * sizeof(uint16_t);
+  }
+
+  const size_t bias_element_size = sizeof(int32_t);
+  const size_t log2_filter_element_size = XNN_LOG2_SIZEOF_INT8_T;
+  return (k_stride << log2_filter_element_size) + bias_element_size +
+         extra_bytes + block_scale_bytes;
+}
+
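+// Packs qb4w weights via the qs8-qb4w GOI/GIO packing routines, then writes
+// the per-block bf16 scales and, when biases are supplied, the per-channel
+// float biases into each packed row.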
+void xnn_pack_qb4_weights_and_biases(
+    uint32_t flags, const struct xnn_gemm_config* gemm_config,
+    size_t input_channels, size_t output_channels, size_t groups,
+    size_t block_size, size_t k_stride, const void* accumulator_init,
+    const void* weights, xnn_init_scale_params_fn init_extra_data0_fn,
+    const void* extra_data0, size_t extra_data0_element_size,
+    xnn_init_scale_params_fn init_extra_data1_fn, const void* extra_data1,
+    size_t extra_data1_element_size, void* packed_weights_ptr,
+    const void* params) {
+  const uint32_t nr = gemm_config->nr;
+  const uint32_t kr = UINT32_C(1) << gemm_config->log2_kr;
+  const uint32_t sr = UINT32_C(1) << gemm_config->log2_sr;
+
+  const size_t extra_bytes_bl = sizeof(uint16_t);
+  const size_t extra_bytes_n = sizeof(uint32_t);
+  if (flags & XNN_FLAG_TRANSPOSE_WEIGHTS) {
+    xnn_pack_qs8_qb4w_gemm_gio_w(
+        /*g=*/groups,
+        /*nc=*/output_channels,
+        /*kc=*/input_channels,
+        /*nr=*/nr,
+        /*kr=*/kr,
+        /*sr=*/sr,
+        /*k_stride=*/k_stride,
+        /*bl=*/block_size,
+        /*k=*/(const uint8_t*)weights,
+        /*bias=*/NULL,
+        /*scale=*/(const xnn_bfloat16*)extra_data1,
+        /*packed_weights=*/packed_weights_ptr,
+        /*extra_bytes_bl=*/nr * extra_bytes_bl,
+        /*extra_bytes_n=*/nr * extra_bytes_n,
+        /*params=*/(const struct xnn_qs8_qc4w_packing_params*)params);
+  } else {
+    xnn_pack_qs8_qb4w_gemm_goi_w(
+        /*g=*/groups,
+        /*nc=*/output_channels,
+        /*kc=*/input_channels,
+        /*nr=*/nr,
+        /*kr=*/kr,
+        /*sr=*/sr,
+        /*bl=*/block_size,
+        /*k=*/(const uint8_t*)weights,
+        /*bias=*/NULL,
+        /*scale=*/(const xnn_bfloat16*)extra_data1,
+        /*packed_weights=*/packed_weights_ptr,
+        /*extra_bytes_bl=*/nr * extra_bytes_bl,
+        /*extra_bytes_n=*/nr * extra_bytes_n,
+        /*params=*/(const struct xnn_qs8_qc4w_packing_params*)params);
+  }
+
+  // Fill in the per-block kernel scales.
+  const size_t num_blocks = input_channels / block_size;
+  const size_t weights_stride = xnn_packed_stride_qb4_weights_and_biases(
+      gemm_config, input_channels, block_size, k_stride, extra_bytes_n);
+  void* weights_start = (void*) ((uintptr_t) packed_weights_ptr +
+      nr * (sizeof(float) + (block_size * sizeof(int8_t) / 2)));
+  const size_t block_stride = /*weights=*/block_size / 2 + sizeof(uint16_t);
+  xnn_init_blockwise_scale_bf16_params(
+      output_channels, nr, nr,
+      nr * weights_stride,
+      nr * weights_stride,
+      /*num_blocks=*/num_blocks,
+      /*block_stride=*/nr * block_stride,
+      0,
+      (const xnn_bfloat16*)extra_data1, weights_start);
+
+  // Fill in the bias, if supplied.
+  if (accumulator_init != nullptr) {
+    weights_start = (void*) ((uintptr_t) packed_weights_ptr +
+        nr * (weights_stride - sizeof(float)));
+    xnn_init_qs8_qc8w_scale_fp32_params(
+        output_channels, nr, nr,
+        nr * weights_stride, nr * weights_stride, 0,
+        (const float*)accumulator_init, weights_start);
+  }
+}
+
 size_t xnn_packed_stride_qu8_weights_and_biases(
-    const struct xnn_gemm_config* gemm_config, size_t unused_k, size_t k_stride,
-    size_t extra_bytes) {
+    const struct xnn_gemm_config* gemm_config, size_t unused_k, size_t unused_block_size,
+    size_t k_stride, size_t extra_bytes) {
   const size_t bias_element_size = sizeof(int32_t);
   const size_t log2_filter_element_size = XNN_LOG2_SIZEOF_INT8_T;
   return (k_stride << log2_filter_element_size) + bias_element_size +
@@ -1487,7 +1591,7 @@ size_t xnn_packed_stride_qu8_weights_and_biases(
 
 void xnn_pack_qu8_weights_and_biases(
     uint32_t flags, const struct xnn_gemm_config* gemm_config,
     size_t input_channels, size_t output_channels, size_t groups,
-    size_t k_stride, const void* accumulator_init, const void* weights,
+    size_t unused_block_size, size_t k_stride, const void* accumulator_init, const void* weights,
     xnn_init_scale_params_fn init_extra_data0_fn, const void* extra_data0,
     size_t extra_data0_element_size,
     xnn_init_scale_params_fn init_extra_data1_fn, const void* extra_data1,
@@ -1496,9 +1600,9 @@ void xnn_pack_qu8_weights_and_biases(
   const size_t extra_bytes =
      extra_data0_element_size + extra_data1_element_size;
   const size_t weights_stride = xnn_packed_stride_qs8_weights_and_biases(
-      gemm_config, input_channels, k_stride, extra_bytes);
+      gemm_config, input_channels, unused_block_size, k_stride, extra_bytes);
   return pack_weights_and_biases(
-      flags, gemm_config, input_channels, output_channels, groups,
+      flags, gemm_config, input_channels, output_channels, groups, unused_block_size,
      weights_stride, (xnn_packw_gemm_gio_ukernel_fn)xnn_pack_qu8_gemm_gio_w,
      (xnn_packw_gemm_goi_ukernel_fn)xnn_pack_qu8_gemm_goi_w, accumulator_init,
      weights, init_extra_data0_fn, extra_data0, extra_data0_element_size,
@@ -1508,7 +1612,7 @@ void xnn_pack_qu8_weights_and_biases(
 
 #if XNN_ENABLE_KLEIDIAI
 size_t xnn_packed_stride_kai_qs4_weights_and_biases(
-    const struct xnn_gemm_config* gemm_config, size_t k, size_t unused_k_stride,
+    const struct xnn_gemm_config* gemm_config, size_t k, size_t unused_block_size, size_t unused_k_stride,
     size_t extra_bytes) {
   const uint32_t kr = UINT32_C(1) << gemm_config->log2_kr;
   const uint32_t sr = UINT32_C(1) << gemm_config->log2_sr;
@@ -1519,7 +1623,7 @@ size_t xnn_packed_stride_kai_qs4_weights_and_biases(
 
 void xnn_pack_kai_qs4_weights_and_biases(
     uint32_t flags, const struct xnn_gemm_config* gemm_config,
     size_t input_channels, size_t output_channels, size_t groups,
-    size_t k_stride, const void* accumulator_init, const void* weights,
+    size_t unused_block_size, size_t k_stride, const void* accumulator_init, const void* weights,
     xnn_init_scale_params_fn init_extra_data0_fn, const void* extra_data0,
     size_t extra_data0_element_size,
     xnn_init_scale_params_fn init_extra_data1_fn, const void* extra_data1,
@@ -1561,8 +1665,8 @@ void xnn_pack_kai_qs4_weights_and_biases(
 }
 
 size_t xnn_packed_stride_kai_f16_weights_and_biases(
-    const struct xnn_gemm_config* gemm_config, size_t k, size_t unused_k_stride,
-    size_t extra_bytes) {
+    const struct xnn_gemm_config* gemm_config, size_t k, size_t unused_block_size,
+    size_t unused_k_stride, size_t extra_bytes) {
   size_t ret_val =
      kai_get_rhs_packed_stride_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme(k) /
      kai_get_n_step_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme();
@@ -1579,8 +1683,8 @@ void transpose_weights_x16(const xnn_float16* in, xnn_float16* out,
 }
 
 size_t xnn_packed_stride_kai_qs8_weights_and_biases(
-    const struct xnn_gemm_config* gemm_config, size_t k, size_t unused_k_stride,
-    size_t extra_bytes) {
+    const struct xnn_gemm_config* gemm_config, size_t k, size_t unused_block_size,
+    size_t unused_k_stride, size_t extra_bytes) {
   const uint32_t kr = UINT32_C(1) << gemm_config->log2_kr;
   const uint32_t sr = UINT32_C(1) << gemm_config->log2_sr;
   return kai_get_rhs_packed_stride_rhs_pack_kxn_qsi8cxp_qsi8cx_neon(k, /*nr=*/1,
@@ -1590,7 +1694,7 @@ size_t xnn_packed_stride_kai_qs8_weights_and_biases(
 
 void xnn_pack_kai_qs8_weights_and_biases(
     uint32_t flags, const struct xnn_gemm_config* gemm_config,
     size_t input_channels, size_t output_channels, size_t groups,
-    size_t k_stride, const void* accumulator_init, const void* weights,
+    size_t unused_block_size, size_t k_stride, const void* accumulator_init, const void* weights,
     xnn_init_scale_params_fn init_extra_data0_fn, const void* extra_data0,
     size_t extra_data0_element_size,
     xnn_init_scale_params_fn init_extra_data1_fn, const void* extra_data1,
@@ -1612,7 +1716,7 @@ void xnn_pack_kai_qs8_weights_and_biases(
   const size_t n_stride = round_up(output_channels, nr);
   const size_t packed_weights_group_stride =
      n_stride * xnn_packed_stride_kai_qs8_weights_and_biases(
-                     gemm_config, input_channels, k_stride,
+                     gemm_config, input_channels, unused_block_size, k_stride,
                      extra_data0_element_size + extra_data1_element_size);
 
   if (flags & XNN_FLAG_TRANSPOSE_WEIGHTS) {
@@ -1659,7 +1763,7 @@ void xnn_pack_kai_qs8_weights_and_biases(
 }
 
 size_t xnn_packed_stride_kai_f32_weights_and_biases(
-    const struct xnn_gemm_config* gemm_config, size_t k, size_t unused_k_stride,
+    const struct xnn_gemm_config* gemm_config, size_t k, size_t unused_block_size, size_t unused_k_stride,
     size_t extra_bytes) {
   size_t ret_val =
      kai_get_rhs_packed_stride_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(k) /
@@ -1669,7 +1773,7 @@ size_t xnn_packed_stride_kai_f32_weights_and_biases(
 
 void xnn_pack_kai_f16_weights_and_biases(
     uint32_t flags, const struct xnn_gemm_config* gemm_config,
-    size_t input_channels, size_t output_channels, size_t groups,
+    size_t input_channels, size_t output_channels, size_t groups, size_t unused_block_size,
     size_t k_stride, const void* accumulator_init, const void* weights,
     xnn_init_scale_params_fn init_extra_data0_fn, const void* extra_data0,
     size_t extra_data0_element_size,
@@ -1732,7 +1836,7 @@ void transpose_weights(const float* in, float* out, size_t height,
 
 void xnn_pack_kai_f32_weights_and_biases(
     uint32_t flags, const struct xnn_gemm_config* gemm_config,
     size_t input_channels, size_t output_channels, size_t groups,
-    size_t k_stride, const void* accumulator_init, const void* weights,
+    size_t unused_block_size, size_t k_stride, const void* accumulator_init, const void* weights,
     xnn_init_scale_params_fn init_extra_data0_fn, const void* extra_data0,
     size_t extra_data0_element_size,
     xnn_init_scale_params_fn init_extra_data1_fn, const void* extra_data1,
@@ -1784,7 +1888,7 @@ void xnn_pack_kai_f32_weights_and_biases(
 
 size_t xnn_packed_stride_kai_qb4_weights_and_biases(
     const struct xnn_gemm_config* gemm_config, size_t k, size_t block_size,
-    size_t extra_bytes) {
+    size_t unused_k_stride, size_t extra_bytes) {
   const uint32_t kr = UINT32_C(1) << gemm_config->log2_kr;
   const uint32_t sr = UINT32_C(1) << gemm_config->log2_sr;
   const uint32_t nr = gemm_config->nr;
@@ -1802,7 +1906,7 @@ size_t xnn_packed_stride_kai_qb4_weights_and_biases(
 
 void xnn_pack_kai_qb4_weights_and_biases(
     uint32_t flags, const struct xnn_gemm_config* gemm_config,
     size_t input_channels, size_t output_channels, size_t groups,
-    size_t block_size, const void* accumulator_init, const void* weights,
+    size_t block_size, size_t unused_k_stride, const void* accumulator_init, const void* weights,
     xnn_init_scale_params_fn init_extra_data0_fn, const void* extra_data0,
     size_t extra_data0_element_size,
     xnn_init_scale_params_fn init_extra_data1_fn, const void* extra_data1,
@@ -1850,6 +1954,19 @@ void xnn_pack_kai_qb4_weights_and_biases(
        /*rhs_packed*/ packed_weights_ptr,
        /*extra_bytes=*/0, &kai_params);
   }
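+
+  // Initialize the bias: copy the caller-supplied float accumulators into the
+  // tail of each packed row (one float per output channel).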
+  if (accumulator_init != NULL) {
+    const size_t weights_stride = xnn_packed_stride_kai_qb4_weights_and_biases(
+        gemm_config, input_channels, block_size, unused_k_stride, 0);
+    void* weights_start = (void*) ((uintptr_t) packed_weights_ptr +
+        nr * (weights_stride - sizeof(float)));
+    xnn_init_qs8_qc8w_scale_fp32_params(
+        output_channels, nr, nr,
+        nr * weights_stride, nr * weights_stride, 0,
+        (const float*)accumulator_init, weights_start);
+  }
 }
 
 #endif  // XNN_ENABLE_KLEIDIAI
diff --git a/src/xnnpack/microfnptr.h b/src/xnnpack/microfnptr.h
index 94bf319b022..48762f6fc39 100644
--- a/src/xnnpack/microfnptr.h
+++ b/src/xnnpack/microfnptr.h
@@ -2244,6 +2244,7 @@ typedef void (*xnn_pack_weights_and_biases_fn)(
    size_t input_channels,   //
    size_t output_channels,  //
    size_t groups,           //
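+    // Input channels per quantization block for blockwise kernels; zero for
+    // channelwise kernels.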
+    size_t block_size,       //
    // We tile packing by output channels, in GIO layout, the k (row) index
    // needs to be able to skip by the actual number of output channels, and not
    // just the argument nc. E.g. if weights is 1x3x5, and nr is 2, we tile the
@@ -2270,6 +2271,7 @@ typedef void (*xnn_pack_weights_and_biases_fn)(
 typedef size_t (*xnn_packed_stride_weights_and_biases_fn)(
    const struct xnn_gemm_config* gemm_config,  //
    size_t k,                                   //
+    size_t block_size,                          //
    size_t k_stride,                            //
    size_t extra_bytes);
diff --git a/src/xnnpack/pack.h b/src/xnnpack/pack.h
index 6791f846f48..b1692a05248 100644
--- a/src/xnnpack/pack.h
+++ b/src/xnnpack/pack.h
@@ -387,6 +387,7 @@ XNN_INTERNAL void xnn_pack_qs8_weights_and_biases(
    size_t input_channels,   //
    size_t output_channels,  //
    size_t groups,           //
+    size_t unused_block_size,  //
    size_t k_stride,         //
    const void* accumulator_init,  //
    const void* weights,     //
@@ -402,6 +403,7 @@ XNN_INTERNAL size_t xnn_packed_stride_qs8_weights_and_biases(
    const struct xnn_gemm_config* gemm_config,  //
    size_t k,                //
+    size_t unused_block_size,  //
    size_t k_stride,         //
    size_t extra_bytes);
 
@@ -412,6 +414,7 @@ XNN_INTERNAL void xnn_pack_qs4_weights_and_biases(
    size_t input_channels,   //
    size_t output_channels,  //
    size_t groups,           //
+    size_t unused_block_size,  //
    size_t k_stride,         //
    const void* accumulator_init,  //
    const void* weights,     //
@@ -427,6 +430,33 @@ XNN_INTERNAL size_t xnn_packed_stride_qs4_weights_and_biases(
    const struct xnn_gemm_config* gemm_config,  //
    size_t k,                //
+    size_t unused_block_size,  //
+    size_t k_stride,         //
+    size_t extra_bytes);
+
+XNN_INTERNAL void xnn_pack_qb4_weights_and_biases(
+    uint32_t flags,          //
+    const struct xnn_gemm_config* gemm_config,  //
+    size_t input_channels,   //
+    size_t output_channels,  //
+    size_t groups,           //
+    size_t block_size,       //
+    size_t k_stride,         //
+    const void* accumulator_init,  //
+    const void* weights,     //
+    xnn_init_scale_params_fn init_extra_data0_fn,  //
+    const void* extra_data0,  //
+    size_t extra_data0_element_size,  //
+    xnn_init_scale_params_fn init_extra_data1_fn,  //
+    const void* extra_data1,  //
+    size_t extra_data1_element_size,  //
+    void* packed_weights_ptr,  //
+    const void* params);
+
+XNN_INTERNAL size_t xnn_packed_stride_qb4_weights_and_biases(
+    const struct xnn_gemm_config* gemm_config,  //
+    size_t k,                //
+    size_t block_size,       //
    size_t k_stride,         //
    size_t extra_bytes);
 
@@ -436,6 +466,7 @@ XNN_INTERNAL void xnn_pack_qu8_weights_and_biases(
    size_t input_channels,   //
    size_t output_channels,  //
    size_t groups,           //
+    size_t unused_block_size,  //
    size_t k_stride,         //
    const void* accumulator_init,  //
    const void* weights,     //
@@ -451,6 +482,7 @@ XNN_INTERNAL size_t xnn_packed_stride_qu8_weights_and_biases(
    const struct xnn_gemm_config* gemm_config,  //
    size_t k,                //
+    size_t unused_block_size,  //
    size_t k_stride,         //
    size_t extra_bytes);
 
@@ -461,6 +493,7 @@ XNN_INTERNAL void xnn_pack_kai_qs4_weights_and_biases(
    size_t input_channels,   //
    size_t output_channels,  //
    size_t groups,           //
+    size_t unused_block_size,  //
    size_t k_stride,         //
    const void* accumulator_init,  //
    const void* weights,     //
@@ -476,6 +509,7 @@ XNN_INTERNAL size_t xnn_packed_stride_kai_qs4_weights_and_biases(
    const struct xnn_gemm_config* gemm_config,  //
    size_t k,                //
+    size_t unused_block_size,  //
    size_t k_stride,         //
    size_t extra_bytes);
 
@@ -485,6 +519,7 @@ XNN_INTERNAL void xnn_pack_kai_qs8_weights_and_biases(
    size_t input_channels,   //
    size_t output_channels,  //
    size_t groups,           //
+    size_t unused_block_size,  //
    size_t k_stride,         //
    const void* accumulator_init,  //
    const void* weights,     //
@@ -500,35 +535,60 @@ XNN_INTERNAL size_t xnn_packed_stride_kai_qs8_weights_and_biases(
    const struct xnn_gemm_config* gemm_config,  //
    size_t k,                //
+    size_t unused_block_size,  //
    size_t k_stride,         //
    size_t extra_bytes);
 
 size_t xnn_packed_stride_kai_f16_weights_and_biases(
-    const struct xnn_gemm_config* gemm_config, size_t k, size_t unused_k_stride,
+    const struct xnn_gemm_config* gemm_config,  //
+    size_t k,                //
+    size_t unused_block_size,  //
+    size_t unused_k_stride,  //
    size_t extra_bytes);
 
 size_t xnn_packed_stride_kai_f32_weights_and_biases(
-    const struct xnn_gemm_config* gemm_config, size_t k, size_t unused_k_stride,
+    const struct xnn_gemm_config* gemm_config,  //
+    size_t k,                //
+    size_t unused_block_size,  //
+    size_t unused_k_stride,  //
    size_t extra_bytes);
 
 void xnn_pack_kai_f16_weights_and_biases(
-    uint32_t flags, const struct xnn_gemm_config* gemm_config,
-    size_t input_channels, size_t output_channels, size_t groups,
-    size_t k_stride, const void* accumulator_init, const void* weights,
-    xnn_init_scale_params_fn init_extra_data0_fn, const void* extra_data0,
-    size_t extra_data0_element_size,
-    xnn_init_scale_params_fn init_extra_data1_fn, const void* extra_data1,
-    size_t extra_data1_element_size, void* packed_weights_ptr,
+    uint32_t flags,          //
+    const struct xnn_gemm_config* gemm_config,  //
+    size_t input_channels,   //
+    size_t output_channels,  //
+    size_t groups,           //
+    size_t unused_block_size,  //
+    size_t k_stride,         //
+    const void* accumulator_init,  //
+    const void* weights,     //
+    xnn_init_scale_params_fn init_extra_data0_fn,  //
+    const void* extra_data0,  //
+    size_t extra_data0_element_size,  //
+    xnn_init_scale_params_fn init_extra_data1_fn,  //
+    const void* extra_data1,  //
+    size_t extra_data1_element_size,  //
+    void* packed_weights_ptr,  //
    const void* params);
 
 void xnn_pack_kai_f32_weights_and_biases(
-    uint32_t flags, const struct xnn_gemm_config* gemm_config,
-    size_t input_channels, size_t output_channels, size_t groups,
-    size_t k_stride, const void* accumulator_init, const void* weights,
-    xnn_init_scale_params_fn init_extra_data0_fn, const void* extra_data0,
-    size_t extra_data0_element_size,
-    xnn_init_scale_params_fn init_extra_data1_fn, const void* extra_data1,
-    size_t extra_data1_element_size, void* packed_weights_ptr,
+    uint32_t flags,          //
+    const struct xnn_gemm_config* gemm_config,  //
+    size_t input_channels,   //
+    size_t output_channels,  //
+    size_t groups,           //
+    size_t unused_block_size,  //
+    size_t k_stride,         //
+    const void* accumulator_init,  //
+    const void* weights,     //
+    xnn_init_scale_params_fn init_extra_data0_fn,  //
+    const void* extra_data0,  //
+    size_t extra_data0_element_size,  //
+    xnn_init_scale_params_fn init_extra_data1_fn,  //
+    const void* extra_data1,  //
+    size_t extra_data1_element_size,  //
+    void* packed_weights_ptr,  //
    const void* params);
 
 XNN_INTERNAL void xnn_pack_kai_qb4_weights_and_biases(
@@ -538,6 +598,7 @@ XNN_INTERNAL void xnn_pack_kai_qb4_weights_and_biases(
    size_t output_channels,  //
    size_t groups,           //
    size_t block_size,       //
+    size_t k_stride,         //
    const void* accumulator_init,  //
    const void* weights,     //
    xnn_init_scale_params_fn init_extra_data0_fn,  //
@@ -553,6 +614,7 @@ XNN_INTERNAL size_t xnn_packed_stride_kai_qb4_weights_and_biases(
    const struct xnn_gemm_config* gemm_config,  //
    size_t k,                //
    size_t block_size,       //
+    size_t k_stride,         //
    size_t extra_bytes);
 
 #endif  // XNN_ENABLE_KLEIDIAI
diff --git a/test/gemm-microkernel-tester.cc b/test/gemm-microkernel-tester.cc
index 79d9f3c8b45..a3051c6bf83 100644
--- a/test/gemm-microkernel-tester.cc
+++ b/test/gemm-microkernel-tester.cc
@@ -1796,7 +1796,7 @@ void GemmMicrokernelTester::Test(
   gemm_config.log2_sr = static_cast<uint8_t>(31 - math_clz_nonzero_u32(sr()));
 
   const size_t packed_w_stride =
-      packed_stride(&gemm_config, k2, /*k_stride=*/k2, /*extra_bytes=*/0);
+      packed_stride(&gemm_config, k2, /*unused_block_size=*/0, /*k_stride=*/k2, /*extra_bytes=*/0);
   const size_t packed_w_size = packed_w_stride * round_up(n(), nr());
 
   xnnpack::Buffer packed_w(packed_w_size);
@@ -1822,7 +1822,8 @@ void GemmMicrokernelTester::Test(
   params.input_zero_point = 1;
   params.kernel_zero_point = b_zero_point();
   pack(/*flags=*/0, &gemm_config, k2, n(),
-       /*groups=*/1, /*k_stride=*/k2,
+       /*groups=*/1, /*unused_block_size=*/0,
+       /*k_stride=*/k2,
       /*accumulator_init=*/nullptr,
       /*weights=*/b.data(),
       /*int_extra_data0_fn=*/nullptr,
@@ -1939,7 +1940,7 @@ void GemmMicrokernelTester::Test_QP8F32QC8W(
   gemm_config.log2_sr = static_cast<uint8_t>(31 - math_clz_nonzero_u32(sr()));
 
   const size_t packed_w_stride =
-      packed_stride(&gemm_config, k(), /*k_stride=*/k(), /*extra_bytes=*/0);
+      packed_stride(&gemm_config, k(), /*unused_block_size=*/0, /*k_stride=*/k(), /*extra_bytes=*/0);
   const size_t packed_w_size = packed_w_stride * round_up(n(), nr());
 
   xnnpack::Buffer packed_w(packed_w_size);
@@ -1965,7 +1966,8 @@ void GemmMicrokernelTester::Test_QP8F32QC8W(
   params.input_zero_point = 1;
   params.scale_multiplier = 1.0f;
   pack(/*flags=*/0, &gemm_config, k(), n(),
-       /*groups=*/1, /*k_stride=*/k(),
+       /*groups=*/1, /*unused_block_size=*/0,
+       /*k_stride=*/k(),
       /*accumulator_init=*/nullptr,
       /*weights=*/b.data(),
       /*int_extra_data0_fn=*/nullptr,
@@ -2085,7 +2087,7 @@ void GemmMicrokernelTester::Test(
   gemm_config.log2_sr = static_cast<uint8_t>(31 - math_clz_nonzero_u32(sr()));
 
   const size_t packed_w_stride =
-      packed_stride(&gemm_config, k2, /*k_stride=*/bl(), /*extra_bytes=*/0);
+      packed_stride(&gemm_config, k2, /*block_size=*/bl(), /*k_stride=*/k2, /*extra_bytes=*/0);
   const size_t packed_w_size = packed_w_stride * round_up(n(), nr());
 
   xnnpack::Buffer packed_w(packed_w_size);
@@ -2113,7 +2115,8 @@ void GemmMicrokernelTester::Test(
   params.input_zero_point = 1;
   params.kernel_zero_point = b_zero_point();
   pack(/*flags=*/0, &gemm_config, k2, n(),
-       /*groups=*/1, /*k_stride=*/bl(),
+       /*groups=*/1, /*block_size=*/bl(),
+       /*k_stride=*/k2,
       /*accumulator_init=*/nullptr,
       /*weights=*/b.data(),
       /*int_extra_data0_fn=*/nullptr,