Skip to content

Commit

Permalink
Merge pull request #7479 from xujuntwt95329:x8_gio_pack
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 698961875
  • Loading branch information
xnnpack-bot committed Nov 22, 2024
2 parents 0270f91 + 916717b commit ee5baf3
Show file tree
Hide file tree
Showing 23 changed files with 17,448 additions and 0 deletions.
218 changes: 218 additions & 0 deletions bench/packw-benchmark.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,66 @@ static void x8_packw(benchmark::State& state,
benchmark::Counter(uint64_t(state.iterations()) * bytes_per_iteration, benchmark::Counter::kIsRate);
}

// Benchmarks an x8 GIO (grouped, input-output order) weight-packing ukernel.
// Cycles through enough independent buffers to defeat the cache, then reports
// throughput in elements and bytes per second.
static void x8_gio_packw(benchmark::State& state,
  xnn_x8_packw_gemm_gio_ukernel_fn packw,
  size_t nr, size_t kr, size_t sr,
  benchmark::utils::IsaCheckFunction isa_check = nullptr)
{
  if (isa_check != nullptr && !isa_check(state)) {
    return;
  }

  const size_t batch = state.range(0);  // batch is g parameter for packw
  const size_t dim_n = state.range(2);  // dim_n is nc parameter
  const size_t dim_k = state.range(3);  // dim_k is kc parameter

  const size_t rounded_n = benchmark::utils::RoundUp(dim_n, nr);
  const size_t rounded_k = benchmark::utils::RoundUp(dim_k, kr * sr);

  // Per-buffer packed size in bytes: packed weights plus one sizeof(uint32_t)
  // bias slot per (rounded) output channel — the kernel writes the bias region
  // even when bias == nullptr.
  const size_t packed_size = rounded_n * rounded_k + rounded_n * sizeof(uint32_t);

  std::random_device random_device;
  auto rng = std::mt19937(random_device());

  // Compute num_buffers that fit cache with source weights + packed_weights.
  const size_t num_buffers = 1 +
    benchmark::utils::DivideRoundUp<size_t>(benchmark::utils::GetMaxCacheSize(),
      sizeof(int8_t) * batch * (dim_n * dim_k + packed_size));

  xnnpack::Buffer<int8_t, XNN_ALLOCATION_ALIGNMENT> weights(num_buffers * batch *
    dim_n * dim_k);
  xnnpack::fill_uniform_random_bits(weights.data(), weights.size(), rng);
  xnnpack::Buffer<int8_t, XNN_ALLOCATION_ALIGNMENT> packed_weights(
      num_buffers * batch * packed_size);

  const xnn_qs8_packw_params params = {127};

  size_t buffer_index = 0;
  for (auto _ : state) {
    if (++buffer_index == num_buffers) {
      buffer_index = 0;
    }

    // Stride by the full per-buffer packed size (weights + bias region) so
    // successive buffers do not overlap.
    packw(batch, dim_n, dim_k, nr, kr, sr, dim_n /* k_stride */,
      weights.data() + buffer_index * batch * dim_n * dim_k,
      /*bias=*/nullptr, /*scale=*/nullptr,
      packed_weights.data() + buffer_index * batch * packed_size,
      /*extra_bytes=*/0, &params);
  }

  const uint64_t cpu_frequency = benchmark::utils::GetCurrentCpuFrequency();
  if (cpu_frequency != 0) {
    state.counters["cpufreq"] = cpu_frequency;
  }

  const size_t elements_per_iteration = batch * dim_n * dim_k;
  state.counters["elements"] =
    benchmark::Counter(uint64_t(state.iterations()) * elements_per_iteration, benchmark::Counter::kIsRate);

  // Bytes touched per iteration: source weights read + packed output written.
  const size_t bytes_per_iteration = (elements_per_iteration + batch * packed_size) * sizeof(int8_t);
  state.counters["bytes"] =
    benchmark::Counter(uint64_t(state.iterations()) * bytes_per_iteration, benchmark::Counter::kIsRate);
}

static void qs8_packw(benchmark::State& state,
xnn_qs8_packw_gemm_goi_ukernel_fn packw,
size_t nr, size_t kr, size_t sr,
Expand Down Expand Up @@ -136,6 +196,66 @@ static void qs8_packw(benchmark::State& state,
benchmark::Counter(uint64_t(state.iterations()) * bytes_per_iteration, benchmark::Counter::kIsRate);
}

// Benchmarks a qs8 GIO (grouped, input-output order) weight-packing ukernel.
// Cycles through enough independent buffers to defeat the cache, then reports
// throughput in elements and bytes per second.
static void qs8_gio_packw(benchmark::State& state,
  xnn_qs8_packw_gemm_gio_ukernel_fn packw,
  size_t nr, size_t kr, size_t sr,
  benchmark::utils::IsaCheckFunction isa_check = nullptr)
{
  if (isa_check != nullptr && !isa_check(state)) {
    return;
  }

  const size_t batch = state.range(0);  // batch is g parameter for packw
  const size_t dim_n = state.range(2);  // dim_n is nc parameter
  const size_t dim_k = state.range(3);  // dim_k is kc parameter

  const size_t rounded_n = benchmark::utils::RoundUp(dim_n, nr);
  const size_t rounded_k = benchmark::utils::RoundUp(dim_k, kr * sr);

  // Per-buffer packed size in bytes: packed weights plus one sizeof(uint32_t)
  // bias slot per (rounded) output channel — the kernel writes the bias region
  // even when bias == nullptr.
  const size_t packed_size = rounded_n * rounded_k + rounded_n * sizeof(uint32_t);

  std::random_device random_device;
  auto rng = std::mt19937(random_device());

  // Compute num_buffers that fit cache with source weights + packed_weights.
  const size_t num_buffers = 1 +
    benchmark::utils::DivideRoundUp<size_t>(benchmark::utils::GetMaxCacheSize(),
      sizeof(int8_t) * batch * (dim_n * dim_k + packed_size));

  xnnpack::Buffer<int8_t, XNN_ALLOCATION_ALIGNMENT> weights(num_buffers * batch *
    dim_n * dim_k);
  xnnpack::fill_uniform_random_bits(weights.data(), weights.size(), rng);
  xnnpack::Buffer<int8_t, XNN_ALLOCATION_ALIGNMENT> packed_weights(
      num_buffers * batch * packed_size);

  const xnn_qs8_packw_params params = {127};

  size_t buffer_index = 0;
  for (auto _ : state) {
    if (++buffer_index == num_buffers) {
      buffer_index = 0;
    }

    // Stride by the full per-buffer packed size (weights + bias region) so
    // successive buffers do not overlap.
    packw(batch, dim_n, dim_k, nr, kr, sr, dim_n /* k_stride */,
      weights.data() + buffer_index * batch * dim_n * dim_k,
      /*bias=*/nullptr, /*scale=*/nullptr,
      packed_weights.data() + buffer_index * batch * packed_size,
      /*extra_bytes=*/0, &params);
  }

  const uint64_t cpu_frequency = benchmark::utils::GetCurrentCpuFrequency();
  if (cpu_frequency != 0) {
    state.counters["cpufreq"] = cpu_frequency;
  }

  const size_t elements_per_iteration = batch * dim_n * dim_k;
  state.counters["elements"] =
    benchmark::Counter(uint64_t(state.iterations()) * elements_per_iteration, benchmark::Counter::kIsRate);

  // Bytes touched per iteration: source weights read + packed output written.
  const size_t bytes_per_iteration = (elements_per_iteration + batch * packed_size) * sizeof(int8_t);
  state.counters["bytes"] =
    benchmark::Counter(uint64_t(state.iterations()) * bytes_per_iteration, benchmark::Counter::kIsRate);
}

static void x16_packw(benchmark::State& state,
xnn_x16_packw_gemm_goi_ukernel_fn packw,
size_t nr, size_t kr, size_t sr,
Expand Down Expand Up @@ -365,6 +485,67 @@ BENCHMARK_BGEMM(x8_packw_x8__reference)
BENCHMARK_BGEMM(x8_packw_x16__reference)
BENCHMARK_BGEMM(x8_packw_x32__reference)

// Reference GIO packing for x8 benchmarks: adapts xnn_pack_f32_qs8w_gemm_gio_w
// to the goi-style ukernel signature, fixing k_stride to dim_n (the natural
// row stride of GIO-laid-out weights).
static void x8_packw_gio__reference(
  size_t batch,
  size_t dim_n,
  size_t dim_k,
  size_t nr,
  size_t kr,
  size_t sr,
  const int8_t* weights,
  const uint32_t* bias,
  const void* scale,
  int8_t* packed_weights,
  size_t extra_bytes,
  const void* params)
{
  // `weights` and `packed_weights` already convert implicitly; only `bias`
  // (reinterpreted as float bits) and `scale` (narrowed from void*) need casts.
  xnn_pack_f32_qs8w_gemm_gio_w(batch, dim_n, dim_k, nr, kr, sr, /*k_stride=*/dim_n,
                               weights,
                               reinterpret_cast<const float*>(bias),
                               static_cast<const float*>(scale),
                               packed_weights,
                               extra_bytes, params);
}

// Reference x8 GIO packing benchmark: nr=2, kr=1, sr=1.
static void x8_packw_gio_x2__reference(benchmark::State& state, const char* net) {
  x8_packw(state, x8_packw_gio__reference, /*nr=*/2, /*kr=*/1, /*sr=*/1);
}
// Reference x8 GIO packing benchmark: nr=4, kr=1, sr=1.
static void x8_packw_gio_x4__reference(benchmark::State& state, const char* net) {
  x8_packw(state, x8_packw_gio__reference, /*nr=*/4, /*kr=*/1, /*sr=*/1);
}
// Reference x8 GIO packing benchmark: nr=8, kr=1, sr=1.
static void x8_packw_gio_x8__reference(benchmark::State& state, const char* net) {
  x8_packw(state, x8_packw_gio__reference, /*nr=*/8, /*kr=*/1, /*sr=*/1);
}
// Reference x8 GIO packing benchmark: nr=16, kr=1, sr=1.
static void x8_packw_gio_x16__reference(benchmark::State& state, const char* net) {
  x8_packw(state, x8_packw_gio__reference, /*nr=*/16, /*kr=*/1, /*sr=*/1);
}
// Reference x8 GIO packing benchmark: nr=32, kr=1, sr=1.
static void x8_packw_gio_x32__reference(benchmark::State& state, const char* net) {
  x8_packw(state, x8_packw_gio__reference, /*nr=*/32, /*kr=*/1, /*sr=*/1);
}

// Reference x8 GIO packing benchmark: nr=8, kr=8, sr=1.
static void x8_packw_gio_x8c8__reference(benchmark::State& state, const char* net) {
  x8_packw(state, x8_packw_gio__reference, /*nr=*/8, /*kr=*/8, /*sr=*/1);
}

BENCHMARK_BGEMM(x8_packw_gio_x2__reference)
BENCHMARK_BGEMM(x8_packw_gio_x4__reference)
BENCHMARK_BGEMM(x8_packw_gio_x8__reference)
BENCHMARK_BGEMM(x8_packw_gio_x16__reference)
BENCHMARK_BGEMM(x8_packw_gio_x32__reference)
BENCHMARK_BGEMM(x8_packw_gio_x8c8__reference)

static void qs8_packw__reference(
size_t batch,
size_t dim_n,
Expand Down Expand Up @@ -428,6 +609,43 @@ static void qs8_packw_x16c8__reference(benchmark::State& state, const char* net)
BENCHMARK_BGEMM(qs8_packw_x8c8__reference)
BENCHMARK_BGEMM(qs8_packw_x16c8__reference)

// Reference GIO packing for qs8 benchmarks: adapts xnn_pack_qs8_gemm_gio_w
// to the goi-style ukernel signature, fixing k_stride to dim_n (the natural
// row stride of GIO-laid-out weights).
static void qs8_packw_gio__reference(
  size_t batch,
  size_t dim_n,
  size_t dim_k,
  size_t nr,
  size_t kr,
  size_t sr,
  const int8_t* weights,
  const int32_t* bias,
  const void* scale,
  int8_t* packed_weights,
  size_t extra_bytes,
  const void* params)
{
  // `weights`, `bias` and `packed_weights` already have compatible types;
  // `scale` and `params` are narrowed from void* with static_cast (the
  // correct named cast for void* -> T*).
  xnn_pack_qs8_gemm_gio_w(batch, dim_n, dim_k, nr, kr, sr, /*k_stride=*/dim_n,
                          weights,
                          bias,
                          static_cast<const float*>(scale),
                          packed_weights,
                          extra_bytes,
                          static_cast<const struct xnn_qs8_packing_params*>(params));
}

// Reference qs8 GIO packing benchmark: nr=8, kr=8, sr=1.
static void qs8_packw_gio_x8c8__reference(benchmark::State& state, const char* net) {
  qs8_packw(state, qs8_packw_gio__reference, /*nr=*/8, /*kr=*/8, /*sr=*/1);
}
// Reference qs8 GIO packing benchmark: nr=16, kr=8, sr=1.
static void qs8_packw_gio_x16c8__reference(benchmark::State& state, const char* net) {
  qs8_packw(state, qs8_packw_gio__reference, /*nr=*/16, /*kr=*/8, /*sr=*/1);
}

BENCHMARK_BGEMM(qs8_packw_gio_x8c8__reference)
BENCHMARK_BGEMM(qs8_packw_gio_x16c8__reference)

static void x16_packw__reference(
size_t batch,
size_t dim_n,
Expand Down
10 changes: 10 additions & 0 deletions bench/qs8-packw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,19 @@ static void qs8_packw(benchmark::State& state, const char* net,
qs8_packw(state, ukernel, nr, kr, sr);
}

// Benchmark-registration shim: validates the ukernel's required arch flags,
// then forwards to the shared qs8_gio_packw harness (the overload without
// arch_flags, defined in packw-benchmark.h).
// NOTE(review): CheckArchFlags' return value is ignored here, matching the
// sibling qs8_packw shim — presumably it marks `state` skipped internally;
// confirm against benchmark::utils.
static void qs8_gio_packw(benchmark::State& state, const char* net,
                          xnn_qs8_packw_gemm_gio_ukernel_fn ukernel,
                          uint64_t arch_flags, size_t nr, size_t kr, size_t sr) {
  benchmark::utils::CheckArchFlags(state, arch_flags);
  qs8_gio_packw(state, ukernel, nr, kr, sr);
}

#define XNN_QS8_UKERNEL(arch_flags, ukernel, nr, kr, sr, kblock, nr_scale, izp) \
BENCHMARK_CAPTURE_BGEMM(qs8_packw, ukernel##_, ukernel, arch_flags, nr, kr, sr);

#define XNN_QS8_GIO_UKERNEL(arch_flags, ukernel, nr, kr, sr, kblock, nr_scale, izp) \
BENCHMARK_CAPTURE_BGEMM(qs8_gio_packw, ukernel##_, ukernel, arch_flags, nr, kr, sr);

#include "qs8-packw/qs8-packw.h"

#undef XNN_QS8_UKERNEL
Expand Down
10 changes: 10 additions & 0 deletions bench/x8-packw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,19 @@ static void x8_packw(benchmark::State& state, const char* net,
x8_packw(state, ukernel, nr, kr, sr);
}

// Benchmark-registration shim: validates the ukernel's required arch flags,
// then forwards to the shared x8_gio_packw harness (the overload without
// arch_flags, defined in packw-benchmark.h).
// NOTE(review): CheckArchFlags' return value is ignored here, matching the
// sibling x8_packw shim — presumably it marks `state` skipped internally;
// confirm against benchmark::utils.
static void x8_gio_packw(benchmark::State& state, const char* net,
                         xnn_x8_packw_gemm_gio_ukernel_fn ukernel,
                         uint64_t arch_flags, size_t nr, size_t kr, size_t sr) {
  benchmark::utils::CheckArchFlags(state, arch_flags);
  x8_gio_packw(state, ukernel, nr, kr, sr);
}

#define XNN_UKERNEL(arch_flags, ukernel, nr, kr, sr, kblock, nr_scale) \
BENCHMARK_CAPTURE_BGEMM(x8_packw, ukernel##_, ukernel, arch_flags, nr, kr, sr);

#define XNN_GIO_UKERNEL(arch_flags, ukernel, nr, kr, sr, kblock, nr_scale) \
BENCHMARK_CAPTURE_BGEMM(x8_gio_packw, ukernel##_, ukernel, arch_flags, nr, kr, sr);

#include "x8-packw/x8-packw.h"
#undef XNN_UKERNEL

Expand Down
9 changes: 9 additions & 0 deletions cmake/gen/scalar_microkernels.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -605,10 +605,16 @@ SET(NON_PROD_SCALAR_MICROKERNEL_SRCS
src/qs8-dwconv/gen/qs8-dwconv-25p4c-minmax-fp32-scalar-lrintf.c
src/qs8-f32-vcvt/gen/qs8-f32-vcvt-scalar-u2.c
src/qs8-f32-vcvt/gen/qs8-f32-vcvt-scalar-u3.c
src/qs8-packw/gen/qs8-packw-x8c4-gemm-gio-scalar.c
src/qs8-packw/gen/qs8-packw-x8c4-gemm-goi-scalar.c
src/qs8-packw/gen/qs8-packw-x8c8-gemm-gio-scalar.c
src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-scalar.c
src/qs8-packw/gen/qs8-packw-x16c4-gemm-gio-scalar.c
src/qs8-packw/gen/qs8-packw-x16c4-gemm-goi-scalar.c
src/qs8-packw/gen/qs8-packw-x16c8-gemm-gio-scalar.c
src/qs8-packw/gen/qs8-packw-x32c4-gemm-gio-scalar.c
src/qs8-packw/gen/qs8-packw-x32c4-gemm-goi-scalar.c
src/qs8-packw/gen/qs8-packw-x64c4-gemm-gio-scalar.c
src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-4p2c-minmax-fp32-scalar-imagic.c
src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-5f5m5l1c1s1r-minmax-fp32-scalar-fmagic.c
src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-5f5m5l1c1s1r-minmax-fp32-scalar-imagic.c
Expand Down Expand Up @@ -689,7 +695,9 @@ SET(NON_PROD_SCALAR_MICROKERNEL_SRCS
src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x4-minmax-fp32-scalar-fmagic.c
src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x4-minmax-fp32-scalar-imagic.c
src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x4-minmax-fp32-scalar-lrintf.c
src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-gio-scalar.c
src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-scalar.c
src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-gio-scalar.c
src/qs8-rsum/gen/qs8-rsum-scalar-u1.c
src/qs8-rsum/gen/qs8-rsum-scalar-u2.c
src/qs8-vadd/gen/qs8-vadd-minmax-scalar-u2.c
Expand Down Expand Up @@ -828,6 +836,7 @@ SET(NON_PROD_SCALAR_MICROKERNEL_SRCS
src/x8-packw/gen/x8-packw-x2-gemm-goi-scalar-u4.c
src/x8-packw/gen/x8-packw-x4-gemm-goi-scalar-u4.c
src/x8-packw/gen/x8-packw-x8-gemm-goi-scalar-u4.c
src/x8-packw/gen/x8-packw-x8c8-gemm-gio-scalar.c
src/x8-packw/gen/x8-packw-x16-gemm-goi-scalar-u4.c
src/x8-packw/gen/x8-packw-x32-gemm-goi-scalar-u4.c
src/x8-transposec/gen/x8-transposec-1x2-scalar-int.c
Expand Down
9 changes: 9 additions & 0 deletions gen/scalar_microkernels.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -602,10 +602,16 @@ NON_PROD_SCALAR_MICROKERNEL_SRCS = [
"src/qs8-dwconv/gen/qs8-dwconv-25p4c-minmax-fp32-scalar-lrintf.c",
"src/qs8-f32-vcvt/gen/qs8-f32-vcvt-scalar-u2.c",
"src/qs8-f32-vcvt/gen/qs8-f32-vcvt-scalar-u3.c",
"src/qs8-packw/gen/qs8-packw-x8c4-gemm-gio-scalar.c",
"src/qs8-packw/gen/qs8-packw-x8c4-gemm-goi-scalar.c",
"src/qs8-packw/gen/qs8-packw-x8c8-gemm-gio-scalar.c",
"src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-scalar.c",
"src/qs8-packw/gen/qs8-packw-x16c4-gemm-gio-scalar.c",
"src/qs8-packw/gen/qs8-packw-x16c4-gemm-goi-scalar.c",
"src/qs8-packw/gen/qs8-packw-x16c8-gemm-gio-scalar.c",
"src/qs8-packw/gen/qs8-packw-x32c4-gemm-gio-scalar.c",
"src/qs8-packw/gen/qs8-packw-x32c4-gemm-goi-scalar.c",
"src/qs8-packw/gen/qs8-packw-x64c4-gemm-gio-scalar.c",
"src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-4p2c-minmax-fp32-scalar-imagic.c",
"src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-5f5m5l1c1s1r-minmax-fp32-scalar-fmagic.c",
"src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-5f5m5l1c1s1r-minmax-fp32-scalar-imagic.c",
Expand Down Expand Up @@ -686,7 +692,9 @@ NON_PROD_SCALAR_MICROKERNEL_SRCS = [
"src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x4-minmax-fp32-scalar-fmagic.c",
"src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x4-minmax-fp32-scalar-imagic.c",
"src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x4-minmax-fp32-scalar-lrintf.c",
"src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-gio-scalar.c",
"src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-scalar.c",
"src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-gio-scalar.c",
"src/qs8-rsum/gen/qs8-rsum-scalar-u1.c",
"src/qs8-rsum/gen/qs8-rsum-scalar-u2.c",
"src/qs8-vadd/gen/qs8-vadd-minmax-scalar-u2.c",
Expand Down Expand Up @@ -825,6 +833,7 @@ NON_PROD_SCALAR_MICROKERNEL_SRCS = [
"src/x8-packw/gen/x8-packw-x2-gemm-goi-scalar-u4.c",
"src/x8-packw/gen/x8-packw-x4-gemm-goi-scalar-u4.c",
"src/x8-packw/gen/x8-packw-x8-gemm-goi-scalar-u4.c",
"src/x8-packw/gen/x8-packw-x8c8-gemm-gio-scalar.c",
"src/x8-packw/gen/x8-packw-x16-gemm-goi-scalar-u4.c",
"src/x8-packw/gen/x8-packw-x32-gemm-goi-scalar-u4.c",
"src/x8-transposec/gen/x8-transposec-1x2-scalar-int.c",
Expand Down
14 changes: 14 additions & 0 deletions scripts/generate-x8-packw.sh
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,20 @@ tools/xngen src/x8-packw/kr-scalar.c.in -D NR=16 -D KR=8 -D TYPE=int8_t -D IZP=0
tools/xngen src/x8-packw/kr-scalar.c.in -D NR=8 -D KR=8 -D TYPE=int8_t -D IZP=128 -o src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-scalar.c &
tools/xngen src/x8-packw/kr-scalar.c.in -D NR=16 -D KR=8 -D TYPE=int8_t -D IZP=128 -o src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-goi-scalar.c &

### GIO packing
# Plain x8 GIO packing (no activation zero point applied): NR=8, KR=8.
tools/xngen src/x8-packw/kr-gio-scalar.c.in -D NR=8 -D KR=8 -D DATATYPE=X8 -D TYPE=int8_t -D IZP=0 -o src/x8-packw/gen/x8-packw-x8c8-gemm-gio-scalar.c &

# qs8 GIO packing, KR=4 variants: NR = 8/16/32/64.
tools/xngen src/x8-packw/kr-gio-scalar.c.in -D NR=8 -D KR=4 -D DATATYPE=QS8 -D TYPE=int8_t -D IZP=0 -o src/qs8-packw/gen/qs8-packw-x8c4-gemm-gio-scalar.c &
tools/xngen src/x8-packw/kr-gio-scalar.c.in -D NR=16 -D KR=4 -D DATATYPE=QS8 -D TYPE=int8_t -D IZP=0 -o src/qs8-packw/gen/qs8-packw-x16c4-gemm-gio-scalar.c &
tools/xngen src/x8-packw/kr-gio-scalar.c.in -D NR=32 -D KR=4 -D DATATYPE=QS8 -D TYPE=int8_t -D IZP=0 -o src/qs8-packw/gen/qs8-packw-x32c4-gemm-gio-scalar.c &
tools/xngen src/x8-packw/kr-gio-scalar.c.in -D NR=64 -D KR=4 -D DATATYPE=QS8 -D TYPE=int8_t -D IZP=0 -o src/qs8-packw/gen/qs8-packw-x64c4-gemm-gio-scalar.c &

# qs8 GIO packing, KR=8 variants: NR = 8/16.
tools/xngen src/x8-packw/kr-gio-scalar.c.in -D NR=8 -D KR=8 -D DATATYPE=QS8 -D TYPE=int8_t -D IZP=0 -o src/qs8-packw/gen/qs8-packw-x8c8-gemm-gio-scalar.c &
tools/xngen src/x8-packw/kr-gio-scalar.c.in -D NR=16 -D KR=8 -D DATATYPE=QS8 -D TYPE=int8_t -D IZP=0 -o src/qs8-packw/gen/qs8-packw-x16c8-gemm-gio-scalar.c &

# qs8 -> qu8 GIO packing: IZP=128 shifts the input zero point for unsigned activations.
tools/xngen src/x8-packw/kr-gio-scalar.c.in -D NR=8 -D KR=8 -D DATATYPE=QS8 -D TYPE=int8_t -D IZP=128 -o src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-gio-scalar.c &
tools/xngen src/x8-packw/kr-gio-scalar.c.in -D NR=16 -D KR=8 -D DATATYPE=QS8 -D TYPE=int8_t -D IZP=128 -o src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-gio-scalar.c &

### AVXVNNI micro-kernels
### C8 packing
tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=8 -D KR=8 -D DATATYPE=QS8 -D TYPE=int8_t -D IZP=0 -D AVX=2 -D PREFETCH=0 -o src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avxvnni.c &
Expand Down
Loading

0 comments on commit ee5baf3

Please sign in to comment.