Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RFC] Multithread packing #7545

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -783,7 +783,10 @@ xnnpack_cc_library(

xnnpack_cxx_library(
name = "packing",
srcs = ["src/reference/packing.cc"],
srcs = [
"src/reference/packing.cc",
"src/packw.c"
],
hdrs = ["src/xnnpack/pack.h"],
defines = xnnpack_configurable_defines(),
deps = [
Expand Down
4 changes: 3 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -793,7 +793,7 @@ ADD_LIBRARY(indirection OBJECT src/indirection.c)
ADD_LIBRARY(logging OBJECT ${LOGGING_SRCS})
ADD_LIBRARY(microparams-init OBJECT src/microparams-init.c)
ADD_LIBRARY(normalization OBJECT src/normalization.c)
ADD_LIBRARY(packing OBJECT src/reference/packing.cc)
ADD_LIBRARY(packing OBJECT src/reference/packing.cc src/packw.c)
TARGET_LINK_LIBRARIES(hardware-config PRIVATE xnnpack-base logging)
TARGET_LINK_LIBRARIES(indirection PRIVATE xnnpack-base)
TARGET_LINK_LIBRARIES(logging PRIVATE xnnpack-base)
Expand Down Expand Up @@ -1469,6 +1469,7 @@ IF(XNNPACK_BUILD_TESTS)
x8-packw
qs8-packw
qs8-qc4w-packw
qb4-packw
x8-zip
xN-transpose
xx-fill
Expand Down Expand Up @@ -1946,6 +1947,7 @@ IF(XNNPACK_BUILD_BENCHMARKS)
qu8-gemm
qu8-gemm-fp32
qu8-gemm-rndnu
qb4-packw
x16-packw
x32-packw
x8-lut
Expand Down
67 changes: 67 additions & 0 deletions bench/packw-benchmark.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,73 @@ static void x8_gio_packw(benchmark::State& state,
benchmark::Counter(uint64_t(state.iterations()) * bytes_per_iteration, benchmark::Counter::kIsRate);
}

static void qb4_packw(benchmark::State& state,
xnn_qb4_packw_gemm_goi_ukernel_fn packw,
size_t nr, size_t kr, size_t sr, size_t bl,
benchmark::utils::IsaCheckFunction isa_check = nullptr)
{
if (isa_check != nullptr && !isa_check(state)) {
return;
}

const size_t batch = 1; // batch is g parameter for packw
const size_t dim_n = state.range(2); // dim_n is nc parameter
const size_t dim_k = state.range(3); // dim_k is kc parameter

const size_t rounded_n = benchmark::utils::RoundUp(dim_n, nr);
const size_t rounded_k = benchmark::utils::RoundUp(dim_k, bl);

std::random_device random_device;
auto rng = std::mt19937(random_device());

  // Compute num_buffers that fit in cache with source weights + packed_weights.
const size_t num_buffers = 1 +
benchmark::utils::DivideRoundUp<size_t>(benchmark::utils::GetMaxCacheSize(),
batch * (dim_n * dim_k + rounded_n * rounded_k + rounded_n * sizeof(uint32_t)));

xnnpack::Buffer<uint8_t, XNN_ALLOCATION_ALIGNMENT> weights(num_buffers * batch *
dim_n * (rounded_k >> 1));
xnnpack::fill_uniform_random_bits(weights.data(), weights.size(), rng);
xnnpack::Buffer<int8_t, XNN_ALLOCATION_ALIGNMENT> packed_weights(
num_buffers * batch *
(rounded_n * (rounded_k >> 1) + rounded_n * sizeof(uint32_t)));
xnnpack::Buffer<int32_t, XNN_ALLOCATION_ALIGNMENT> bias(num_buffers * batch * dim_n);
xnnpack::fill_uniform_random_bits(bias.data(), bias.size(), rng);
size_t num_blocks = rounded_k / bl;
xnnpack::Buffer<xnn_bfloat16, XNN_ALLOCATION_ALIGNMENT> bf16_scales(num_blocks * batch * dim_n);
xnnpack::fill_uniform_random_bits(bf16_scales.data(), bf16_scales.size(), rng);

const xnn_qs8_qc4w_packing_params packing_params = { 1, 8 };

size_t buffer_index = 0;
for (auto _ : state) {
if (++buffer_index == num_buffers) {
buffer_index = 0;
}

packw(1, dim_n, rounded_k, nr, kr, sr, bl,
weights.data() + buffer_index * batch * dim_n * (rounded_k >> 1),
/*bias=*/bias.data() + buffer_index * batch * dim_n,
/*scale=*/bf16_scales.data() + buffer_index * batch * dim_n,
packed_weights.data() + buffer_index * batch * (rounded_n * (rounded_k >> 1) + rounded_n * sizeof(uint32_t) + rounded_n * sizeof(uint16_t)),
/*extra_bytes_bl=*/sizeof(uint16_t) * nr, sizeof(float), &packing_params);
}

const uint64_t cpu_frequency = benchmark::utils::GetCurrentCpuFrequency();
if (cpu_frequency != 0) {
state.counters["cpufreq"] = cpu_frequency;
}

const size_t elements_per_iteration = batch * dim_n * (rounded_k >> 1);
state.counters["elements"] =
benchmark::Counter(uint64_t(state.iterations()) * elements_per_iteration, benchmark::Counter::kIsRate);

const size_t bytes_per_iteration = (elements_per_iteration + batch * (rounded_n * rounded_k + rounded_n * sizeof(uint32_t)));
state.counters["bytes"] =
benchmark::Counter(uint64_t(state.iterations()) * bytes_per_iteration, benchmark::Counter::kIsRate);
}


static void qs8_packw(benchmark::State& state,
xnn_qs8_packw_gemm_goi_ukernel_fn packw,
size_t nr, size_t kr, size_t sr,
Expand Down
32 changes: 32 additions & 0 deletions bench/qb4-packw.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright 2024 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.


#include <benchmark/benchmark.h>
#include "bgemm.h"
#include "packw-benchmark.h"
#include "utils.h"
#include "xnnpack/common.h"
#include "xnnpack/hardware-config.h"
#include "xnnpack/packw.h"

static void qb4_packw(benchmark::State& state, const char* net,
xnn_qb4_packw_gemm_goi_ukernel_fn ukernel,
uint64_t arch_flags, size_t nr, size_t kr, size_t sr, size_t bl) {
benchmark::utils::CheckArchFlags(state, arch_flags);
qb4_packw(state, ukernel, nr, kr, sr, bl);
}

#define XNN_QB4_UKERNEL(arch_flags, ukernel, nr, kr, sr, bl, kblock, nr_scale, izp) \
BENCHMARK_CAPTURE_BGEMM(qb4_packw, ukernel##_, ukernel, arch_flags, nr, kr, sr, bl);

#include "qb4-packw/qb4-packw.h"

#undef XNN_QB4_UKERNEL


#ifndef XNNPACK_BENCHMARK_NO_MAIN
BENCHMARK_MAIN();
#endif
1 change: 0 additions & 1 deletion bench/qs8-packw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,3 @@ BENCHMARK_CAPTURE_BGEMM(qs8_gio_packw, ukernel##_, ukernel, arch_flags, nr, kr,
#ifndef XNNPACK_BENCHMARK_NO_MAIN
BENCHMARK_MAIN();
#endif

6 changes: 4 additions & 2 deletions cmake/gen/scalar_microkernels.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,8 @@ SET(PROD_SCALAR_MICROKERNEL_SRCS
src/f32-vunary/gen/f32-vabs-scalar.c
src/f32-vunary/gen/f32-vneg-scalar.c
src/f32-vunary/gen/f32-vsqr-scalar.c
src/qb4-packw/gen/qb4-packw-x16c4-gemm-goi-scalar.c
src/qb4-packw/gen/qb4-packw-x16c8-gemm-goi-scalar.c
src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-1x4-minmax-scalar.c
src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-4x4-minmax-scalar.c
src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x4-minmax-scalar.c
Expand All @@ -157,8 +159,6 @@ SET(PROD_SCALAR_MICROKERNEL_SRCS
src/qs8-f32-vcvt/gen/qs8-f32-vcvt-scalar-u4.c
src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-scalar.c
src/qs8-packw/gen/qs8-packw-x64c4-gemm-goi-scalar.c
src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x8c8-gemm-goi-scalar.c
src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x16c8-gemm-goi-scalar.c
src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-3p1c-minmax-fp32-scalar-fmagic.c
src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-3p2c-minmax-fp32-scalar-imagic.c
src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-3p2c-minmax-fp32-scalar-lrintf.c
Expand Down Expand Up @@ -621,6 +621,8 @@ SET(NON_PROD_SCALAR_MICROKERNEL_SRCS
src/qs8-packw/gen/qs8-packw-x32c4-gemm-gio-scalar.c
src/qs8-packw/gen/qs8-packw-x32c4-gemm-goi-scalar.c
src/qs8-packw/gen/qs8-packw-x64c4-gemm-gio-scalar.c
src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x8c8-gemm-goi-scalar.c
src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x16c8-gemm-goi-scalar.c
src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x32c8-gemm-goi-scalar.c
src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-4p2c-minmax-fp32-scalar-imagic.c
src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-5f5m5l1c1s1r-minmax-fp32-scalar-fmagic.c
Expand Down
6 changes: 4 additions & 2 deletions gen/scalar_microkernels.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ PROD_SCALAR_MICROKERNEL_SRCS = [
"src/f32-vunary/gen/f32-vabs-scalar.c",
"src/f32-vunary/gen/f32-vneg-scalar.c",
"src/f32-vunary/gen/f32-vsqr-scalar.c",
"src/qb4-packw/gen/qb4-packw-x16c4-gemm-goi-scalar.c",
"src/qb4-packw/gen/qb4-packw-x16c8-gemm-goi-scalar.c",
"src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-1x4-minmax-scalar.c",
"src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-4x4-minmax-scalar.c",
"src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x4-minmax-scalar.c",
Expand All @@ -153,8 +155,6 @@ PROD_SCALAR_MICROKERNEL_SRCS = [
"src/qs8-f32-vcvt/gen/qs8-f32-vcvt-scalar-u4.c",
"src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-scalar.c",
"src/qs8-packw/gen/qs8-packw-x64c4-gemm-goi-scalar.c",
"src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x8c8-gemm-goi-scalar.c",
"src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x16c8-gemm-goi-scalar.c",
"src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-3p1c-minmax-fp32-scalar-fmagic.c",
"src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-3p2c-minmax-fp32-scalar-imagic.c",
"src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-3p2c-minmax-fp32-scalar-lrintf.c",
Expand Down Expand Up @@ -618,6 +618,8 @@ NON_PROD_SCALAR_MICROKERNEL_SRCS = [
"src/qs8-packw/gen/qs8-packw-x32c4-gemm-gio-scalar.c",
"src/qs8-packw/gen/qs8-packw-x32c4-gemm-goi-scalar.c",
"src/qs8-packw/gen/qs8-packw-x64c4-gemm-gio-scalar.c",
"src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x8c8-gemm-goi-scalar.c",
"src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x16c8-gemm-goi-scalar.c",
"src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x32c8-gemm-goi-scalar.c",
"src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-4p2c-minmax-fp32-scalar-imagic.c",
"src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-5f5m5l1c1s1r-minmax-fp32-scalar-fmagic.c",
Expand Down
6 changes: 4 additions & 2 deletions include/xnnpack.h
Original file line number Diff line number Diff line change
Expand Up @@ -3925,7 +3925,8 @@ enum xnn_status xnn_create_fully_connected_nc_qd8_f16_qb4w(
uint32_t flags,
xnn_code_cache_t code_cache,
xnn_weights_cache_t weights_cache,
xnn_operator_t* fully_connected_op_out);
xnn_operator_t* fully_connected_op_out,
pthreadpool_t threadpool);

enum xnn_status xnn_reshape_fully_connected_nc_qd8_f16_qb4w(
xnn_operator_t fully_connected_op,
Expand Down Expand Up @@ -3980,7 +3981,8 @@ enum xnn_status xnn_create_fully_connected_nc_qd8_f32_qb4w(
uint32_t flags,
xnn_code_cache_t code_cache,
xnn_weights_cache_t weights_cache,
xnn_operator_t* fully_connected_op_out);
xnn_operator_t* fully_connected_op_out,
pthreadpool_t threadpool);

enum xnn_status xnn_reshape_fully_connected_nc_qd8_f32_qb4w(
xnn_operator_t fully_connected_op,
Expand Down
2 changes: 1 addition & 1 deletion scripts/build-local.sh
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ mkdir -p build/local
CMAKE_ARGS=()

# CMake-level configuration
CMAKE_ARGS+=("-DCMAKE_BUILD_TYPE=Release")
CMAKE_ARGS+=("-DCMAKE_BUILD_TYPE=Debug")
CMAKE_ARGS+=("-DCMAKE_POSITION_INDEPENDENT_CODE=ON")

# If Ninja is installed, prefer it to Make
Expand Down
5 changes: 5 additions & 0 deletions scripts/generate-qb4-packw.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# C8 Packing
tools/xngen src/qb4-packw/kr-scalar.c.in -D NR=16 -D KR=8 -o src/qb4-packw/gen/qb4-packw-x16c8-gemm-goi-scalar.c

# C4 Packing
tools/xngen src/qb4-packw/kr-scalar.c.in -D NR=16 -D KR=4 -o src/qb4-packw/gen/qb4-packw-x16c4-gemm-goi-scalar.c
12 changes: 9 additions & 3 deletions src/configs/gemm-config.c
Original file line number Diff line number Diff line change
Expand Up @@ -1484,7 +1484,8 @@ static void init_qd8_f16_qc4w_gemm_config(void) {
}

static void init_qd8_f16_qb4w_gemm_config(void) {
qd8_f16_qb4w_gemm_config.pack_gemm_goi_bl = (xnn_packw_gemm_goi_bl_ukernel_fn) xnn_pack_qs8_qb4w_gemm_goi_w;
qd8_f16_qb4w_gemm_config.packed_stride_weights_and_biases = xnn_packed_stride_qb4_weights_and_biases;
qd8_f16_qb4w_gemm_config.pack_weights_and_biases = xnn_pack_qb4_weights_and_biases;

#if XNN_ARCH_ARM && XNN_ENABLE_ARM_FP16_VECTOR && XNN_ENABLE_ARM_FP16_SCALAR
const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config();
Expand Down Expand Up @@ -1744,7 +1745,8 @@ static void init_qp8_f32_qb4w_gemm_config(void) {
}

static void init_qdu8_f32_qb4w_gemm_config(void) {
qdu8_f32_qb4w_gemm_config.pack_gemm_goi_bl = (xnn_packw_gemm_goi_bl_ukernel_fn) xnn_pack_qs8_qb4w_gemm_goi_w;
qdu8_f32_qb4w_gemm_config.packed_stride_weights_and_biases = xnn_packed_stride_qb4_weights_and_biases;
qdu8_f32_qb4w_gemm_config.pack_weights_and_biases = xnn_pack_qb4_weights_and_biases;
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config();
assert(hardware_config != NULL);
Expand Down Expand Up @@ -1777,7 +1779,8 @@ static void init_qdu8_f32_qb4w_gemm_config(void) {
}

static void init_qd8_f32_qb4w_gemm_config(void) {
qd8_f32_qb4w_gemm_config.pack_gemm_goi_bl = (xnn_packw_gemm_goi_bl_ukernel_fn) xnn_pack_qs8_qb4w_gemm_goi_w;
qd8_f32_qb4w_gemm_config.packed_stride_weights_and_biases = xnn_packed_stride_qb4_weights_and_biases;
qd8_f32_qb4w_gemm_config.pack_weights_and_biases = xnn_pack_qb4_weights_and_biases;

#if XNN_ARCH_ARM
const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config();
Expand All @@ -1792,6 +1795,7 @@ static void init_qd8_f32_qb4w_gemm_config(void) {
qd8_f32_qb4w_gemm_config.nr = 16;
qd8_f32_qb4w_gemm_config.log2_kr = 2;
qd8_f32_qb4w_gemm_config.planes = 2;
qd8_f32_qb4w_gemm_config.pack_weights_and_biases = xnn_pack_qb4_x16c4_weights_and_biases;
#endif // XNN_ENABLE_ARM_DOTPROD
} else {
qd8_f32_qb4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x16__neon_mlal_lane);
Expand Down Expand Up @@ -1821,6 +1825,7 @@ static void init_qd8_f32_qb4w_gemm_config(void) {
qd8_f32_qb4w_gemm_config.nr = 16;
qd8_f32_qb4w_gemm_config.log2_kr = 3;
qd8_f32_qb4w_gemm_config.planes = 2;
qd8_f32_qb4w_gemm_config.pack_weights_and_biases = xnn_pack_qb4_x16c8_weights_and_biases;
#endif // XNN_ENABLE_ARM_I8MM
} else if (XNN_ENABLE_ARM_DOTPROD && hardware_config->use_arm_neon_dot) {
#if XNN_ENABLE_ARM_DOTPROD
Expand All @@ -1831,6 +1836,7 @@ static void init_qd8_f32_qb4w_gemm_config(void) {
qd8_f32_qb4w_gemm_config.nr = 16;
qd8_f32_qb4w_gemm_config.log2_kr = 2;
qd8_f32_qb4w_gemm_config.planes = 2;
qd8_f32_qb4w_gemm_config.pack_weights_and_biases = xnn_pack_qb4_x16c4_weights_and_biases;
#endif // XNN_ENABLE_ARM_DOTPROD
} else {
qd8_f32_qb4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x16__neon_mlal_lane);
Expand Down
17 changes: 12 additions & 5 deletions src/operators/batch-matrix-multiply-nc.c
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,9 @@ enum xnn_status xnn_create_batch_matrix_multiply_nc_f32_const_weights(
// Pack the weights.
if (gemm_config->pack_weights_and_biases) {
gemm_config->pack_weights_and_biases(flags, gemm_config, k, n,
/*groups=*/batch_size_b, k_stride,
/*groups=*/batch_size_b,
/*unused_block_size=*/0,
/*kstride=*/k_stride,
/*accumulator_init=*/NULL,
/*weights=*/data_b,
/*int_extra_data0_fn=*/NULL,
Expand All @@ -191,7 +193,8 @@ enum xnn_status xnn_create_batch_matrix_multiply_nc_f32_const_weights(
/*extra_data1=*/NULL,
/*extra_data1_size=*/0,
/*packed_weights_ptr=*/packed_data,
/*packing_params=*/NULL);
/*packing_params=*/NULL,
/*pthreadpool=*/NULL);
} else {
if (flags & XNN_FLAG_TRANSPOSE_WEIGHTS) {
batch_matrix_multiply_op->ukernel.gemm.packw_gemm_goi(
Expand Down Expand Up @@ -313,7 +316,7 @@ enum xnn_status create_batch_matrix_multiply_nc_qx8_f32_qc8w(
const size_t weights_stride =
gemm_config->packed_stride_weights_and_biases
? gemm_config->packed_stride_weights_and_biases(
gemm_config, k, k_stride, extra_bytes)
gemm_config, k, /*unused_block_size=*/0, k_stride, extra_bytes)
: (k_stride << XNN_LOG2_SIZEOF_INT8_T) + extra_bytes +
sizeof(int32_t);
assert(weights_stride == (k_stride << XNN_LOG2_SIZEOF_INT8_T) +
Expand Down Expand Up @@ -345,7 +348,9 @@ enum xnn_status create_batch_matrix_multiply_nc_qx8_f32_qc8w(
batch_matrix_multiply_op->flags ^ XNN_FLAG_TRANSPOSE_WEIGHTS,
gemm_config, /*input_channels=*/k,
/*output_channels=*/n,
/*groups=*/batch_size_b, k_stride,
/*groups=*/batch_size_b,
/*unused_block_size=*/0,
/*k_stride=*/k_stride,
/*accumulator_init=*/NULL,
/*weights=*/data_b,
/*int_extra_data0_fn=*/
Expand All @@ -356,7 +361,9 @@ enum xnn_status create_batch_matrix_multiply_nc_qx8_f32_qc8w(
(xnn_init_scale_params_fn)xnn_init_qs8_qc8w_scale_fp32_params,
/*extra_data1=*/scale_b,
/*extra_data1_size=*/sizeof(float),
/*packed_weights_ptr=*/packed_data, &pack_gemm_params);
/*packed_weights_ptr=*/packed_data,
/*params=*/&pack_gemm_params,
/*pthreadpool=*/NULL);
} else {
if (batch_matrix_multiply_op->flags & XNN_FLAG_TRANSPOSE_WEIGHTS) {
batch_matrix_multiply_op->ukernel.gemm.packw_gemm_goi(
Expand Down
5 changes: 4 additions & 1 deletion src/operators/convolution-nhwc.c
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,7 @@ static enum xnn_status create_gemm_or_igemm(
gemm_config->pack_weights_and_biases(
flags, gemm_config, group_input_channels, group_output_channels,
groups,
/*unused_block_size=*/0,
k_stride,
/*accumulator_init=*/bias,
/*weights=*/kernel,
Expand All @@ -383,7 +384,9 @@ static enum xnn_status create_gemm_or_igemm(
/*extra_data1=*/(const void *) kernel_scale_params,
/*extra_data1_size=*/init_kernel_scale_params != NULL ? sizeof(float)
: 0,
/*packed_weights_ptr=*/weights_ptr, packing_params);
/*packed_weights_ptr=*/weights_ptr,
/*params=*/packing_params,
/*pthreadpool=*/NULL);
// Kernel and bias have already been packed so prevent them from being
// packed again below.
weights_already_cached = true;
Expand Down
Loading