Skip to content

Commit

Permalink
Remove extended weights support
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 657795516
  • Loading branch information
fbarchard authored and xnnpack-bot committed Jul 31, 2024
1 parent d3806ed commit f0aa7a9
Show file tree
Hide file tree
Showing 7 changed files with 24 additions and 292 deletions.
35 changes: 6 additions & 29 deletions bench/gemm-benchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@
void GEMMBenchmark(benchmark::State& state, xnn_qs8_gemm_minmax_ukernel_fn gemm,
xnn_init_qs8_conv_minmax_params_fn init_params,
xnn_pack_qs8_gemm_fn pack, size_t mr, size_t nr, size_t kr,
size_t sr, benchmark::utils::IsaCheckFunction isa_check,
bool extended_weights) {
size_t sr, benchmark::utils::IsaCheckFunction isa_check) {
if (isa_check != nullptr && !isa_check(state)) {
return;
}
Expand Down Expand Up @@ -61,10 +60,8 @@ void GEMMBenchmark(benchmark::State& state, xnn_qs8_gemm_minmax_ukernel_fn gemm,
std::vector<int32_t> b(nc);
std::generate(b.begin(), b.end(), std::ref(i32rng));

const size_t w_element_size =
extended_weights ? sizeof(int16_t) : sizeof(int8_t);
const size_t w_size =
nc_stride * sizeof(int32_t) + kc_stride * nc_stride * w_element_size;
const size_t w_element_size = sizeof(int8_t);
const size_t w_size = nc_stride * sizeof(int32_t) + kc_stride * nc_stride * w_element_size;
const size_t c_elements = mc * nc;
const size_t num_buffers = 1 + benchmark::utils::DivideRoundUp<size_t>(
benchmark::utils::GetMaxCacheSize(),
Expand Down Expand Up @@ -127,8 +124,7 @@ void GEMMBenchmark(benchmark::State& state,
xnn_qs8_qc8w_gemm_minmax_ukernel_fn gemm,
xnn_init_qs8_qc8w_conv_minmax_params_fn init_params,
xnn_pack_qs8_gemm_fn pack, size_t mr, size_t nr, size_t kr,
size_t sr, benchmark::utils::IsaCheckFunction isa_check,
bool extended_weights) {
size_t sr, benchmark::utils::IsaCheckFunction isa_check) {
if (isa_check != nullptr && !isa_check(state)) {
return;
}
Expand Down Expand Up @@ -156,10 +152,8 @@ void GEMMBenchmark(benchmark::State& state,
std::vector<int32_t> b(nc);
std::generate(b.begin(), b.end(), std::ref(i32rng));

const size_t w_element_size =
extended_weights ? sizeof(int16_t) : sizeof(int8_t);
const size_t w_size =
nc_stride * sizeof(int32_t) + kc_stride * nc_stride * w_element_size;
const size_t w_element_size = sizeof(int8_t);
const size_t w_size = nc_stride * sizeof(int32_t) + kc_stride * nc_stride * w_element_size;
const size_t c_elements = mc * nc;
const size_t num_buffers = 1 + benchmark::utils::DivideRoundUp<size_t>(
benchmark::utils::GetMaxCacheSize(),
Expand Down Expand Up @@ -216,23 +210,6 @@ void GEMMBenchmark(benchmark::State& state,
benchmark::Counter::kIsRate);
}

void GEMMBenchmark(benchmark::State& state,
xnn_qs8_qc8w_gemm_minmax_ukernel_fn gemm,
xnn_init_qs8_qc8w_conv_minmax_params_fn init_params,
xnn_pack_qs8_gemm_fn pack, size_t mr, size_t nr, size_t kr,
size_t sr, benchmark::utils::IsaCheckFunction isa_check) {
return GEMMBenchmark(state, gemm, init_params, pack, mr, nr, kr, sr,
isa_check, /*extended_weights=*/false);
}

void GEMMBenchmark(benchmark::State& state, xnn_qs8_gemm_minmax_ukernel_fn gemm,
xnn_init_qs8_conv_minmax_params_fn init_params,
xnn_pack_qs8_gemm_fn pack, size_t mr, size_t nr, size_t kr,
size_t sr, benchmark::utils::IsaCheckFunction isa_check) {
return GEMMBenchmark(state, gemm, init_params, pack, mr, nr, kr, sr,
isa_check, /*extended_weights=*/false);
}

void GEMMBenchmark(benchmark::State& state,
xnn_qd8_f16_qc8w_gemm_ukernel_fn gemm,
xnn_init_f16_minmax_params_fn init_params,
Expand Down
26 changes: 2 additions & 24 deletions bench/gemm-benchmark.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,35 +23,13 @@
void GEMMBenchmark(benchmark::State& state, xnn_qs8_gemm_minmax_ukernel_fn gemm,
xnn_init_qs8_conv_minmax_params_fn init_params,
xnn_pack_qs8_gemm_fn pack, size_t mr, size_t nr, size_t kr,
size_t sr, benchmark::utils::IsaCheckFunction isa_check,
bool extended_weights);
size_t sr, benchmark::utils::IsaCheckFunction isa_check);

void GEMMBenchmark(benchmark::State& state,
xnn_qs8_qc8w_gemm_minmax_ukernel_fn gemm,
xnn_init_qs8_qc8w_conv_minmax_params_fn init_params,
xnn_pack_qs8_gemm_fn pack, size_t mr, size_t nr, size_t kr,
size_t sr, benchmark::utils::IsaCheckFunction isa_check,
bool extended_weights);

static void GEMMBenchmark(benchmark::State& state,
xnn_qs8_qc8w_gemm_minmax_ukernel_fn gemm,
xnn_init_qs8_qc8w_conv_minmax_params_fn init_params,
xnn_pack_qs8_gemm_fn pack, size_t mr, size_t nr,
size_t kr, size_t sr,
benchmark::utils::IsaCheckFunction isa_check) {
return GEMMBenchmark(state, gemm, init_params, pack, mr, nr, kr, sr,
isa_check, /*extended_weights=*/false);
}

static void GEMMBenchmark(benchmark::State& state,
xnn_qs8_gemm_minmax_ukernel_fn gemm,
xnn_init_qs8_conv_minmax_params_fn init_params,
xnn_pack_qs8_gemm_fn pack, size_t mr, size_t nr,
size_t kr, size_t sr,
benchmark::utils::IsaCheckFunction isa_check) {
return GEMMBenchmark(state, gemm, init_params, pack, mr, nr, kr, sr,
isa_check, /*extended_weights=*/false);
}
size_t sr, benchmark::utils::IsaCheckFunction isa_check);

void GEMMBenchmark(benchmark::State& state,
xnn_qd8_f16_qc8w_gemm_ukernel_fn gemm,
Expand Down
64 changes: 0 additions & 64 deletions src/packing.c
Original file line number Diff line number Diff line change
Expand Up @@ -887,70 +887,6 @@ void xnn_pack_f32_qc4w_gemm_goi_w(
} while (--g != 0);
}

// Packs QS8 GEMM weights from GOI (groups, output channels, input channels)
// layout into the "extended weights" panel format: each int8 kernel value is
// widened and stored as int16 in the packed buffer.
//
// Packed layout, per nr-wide column block:
//   - nr int32 bias words (pre-adjusted, see below; zero-padded past nc),
//   - interleaved int16 weights covering kc rounded up to a multiple of
//     sr * kr, zero-padded in both the n and k directions.
//
// Arguments:
//   g              - number of groups (must be nonzero)
//   nc             - number of output channels
//   kc             - number of input channels (reduction dimension)
//   nr, kr, sr     - micro-kernel tiling parameters (nr >= sr required)
//   k              - kernel weights, g * nc * kc int8 values, GOI order
//   b              - bias, g * nc int32 values; may be NULL (packed as zeros)
//   scale          - unused by this packing variant (signature parity with
//                    other xnn_pack_*_gemm_* functions)
//   packed_weights - destination buffer (must not be NULL)
//   extra_bytes    - bytes to skip after each group's packed data
//   params         - supplies the activation (input) zero point used for the
//                    bias pre-adjustment
void xnn_pack_qs8_gemm_xw_goi_w(
  size_t g,
  size_t nc,
  size_t kc,
  size_t nr,
  size_t kr,
  size_t sr,
  const int8_t* k,
  const int32_t* b,
  const float* scale,
  void* packed_weights,
  size_t extra_bytes,
  const struct xnn_qs8_packing_params* params)
{
  assert(g != 0);
  assert(nr >= sr);
  assert(k != NULL);
  assert(packed_weights != NULL);

  // skr is the span over which weights are shuffled by the sr interleaving.
  const size_t skr = sr * kr;
  const uint32_t izp = (uint32_t) params->input_zero_point;
  do {
    for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
      const size_t nr_block_size = min(nc - nr_block_start, nr);
      // Remember where this block's bias words start; they are updated in
      // place below once the per-column weight sums are known.
      int32_t* packed_b = (int32_t*) packed_weights;
      if XNN_LIKELY(b != NULL) {
        for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
          unaligned_store_s32(packed_weights, b[nr_block_start + nr_block_offset]);
          packed_weights = (void*) ((uintptr_t) packed_weights + sizeof(int32_t));
        }
      } else {
        // No bias supplied: pack zeros so the layout is unchanged.
        size_t n = nr_block_size;
        do {
          unaligned_store_s32(packed_weights, 0);
          packed_weights = (void*) ((uintptr_t) packed_weights + sizeof(int32_t));
        } while (--n != 0);
      }
      // Skip the bias slots of the (nr - nr_block_size) padding columns.
      packed_weights = (void*) ((uintptr_t) packed_weights + (nr - nr_block_size) * sizeof(int32_t));

      for (size_t kr_block_start = 0; kr_block_start < round_up_po2(kc, skr); kr_block_start += kr) {
        for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
          uint32_t ksum = 0;
          for (size_t kr_block_offset = 0; kr_block_offset < kr; kr_block_offset++) {
            // Map the (kr_block, offset, column) position back to the source
            // k index, applying the sr shuffle within each skr-sized span.
            const size_t kc_idx = round_down_po2(kr_block_start, skr) + ((kr_block_start + kr_block_offset + nr_block_offset * kr) & (skr - 1));
            if (kc_idx < kc) {
              const int8_t kv = k[(nr_block_start + nr_block_offset) * kc + kc_idx];
              ksum += (uint32_t) kv;
              // Extended weights: widen the int8 value to int16 on store.
              ((int16_t*) packed_weights)[kr_block_offset] = (int16_t) kv;
            }
            // kc_idx >= kc: leave the slot as-is (caller-provided padding).
          }
          // Fold the zero-point correction into the bias: subtract
          // (sum of this column's weights) * input_zero_point.
          unaligned_indexed_store_u32(packed_b, nr_block_offset, unaligned_indexed_load_u32(packed_b, nr_block_offset) - ksum * izp);
          packed_weights = (int16_t*) packed_weights + kr;
        }
        // Skip the weight slots of the padding columns in this kr block.
        packed_weights = (int16_t*) packed_weights + (nr - nr_block_size) * kr;
      }
      packed_weights = (void*) ((uintptr_t) packed_weights + extra_bytes);
    }
    // Advance to the next group's weights (and bias, when present).
    k += nc * kc;
    if XNN_UNPREDICTABLE(b != NULL) {
      b += nc;
    }
  } while (--g != 0);
}

void xnn_pack_f32_gemm_gio_w(
size_t g,
size_t nc,
Expand Down
14 changes: 0 additions & 14 deletions src/xnnpack/pack.h
Original file line number Diff line number Diff line change
Expand Up @@ -295,20 +295,6 @@ XNN_INTERNAL void xnn_pack_f32_qs8w_gemm_goi_w(
size_t extra_bytes,
const void* params);

XNN_INTERNAL void xnn_pack_qs8_gemm_xw_goi_w(
size_t g,
size_t nc,
size_t kc,
size_t nr,
size_t kr,
size_t sr,
const int8_t* kernel,
const int32_t* bias,
const float* scale,
void* packed_weights,
size_t extra_bytes,
const struct xnn_qs8_packing_params* params);

XNN_INTERNAL void xnn_pack_f32_gemm_gio_w(
size_t g,
size_t nc,
Expand Down
42 changes: 14 additions & 28 deletions test/gemm-microkernel-tester.cc
Original file line number Diff line number Diff line change
Expand Up @@ -622,7 +622,6 @@ void GemmMicrokernelTester::Test(
std::vector<int8_t> b(n() * k());
std::vector<int32_t> bias(n());
std::vector<int8_t, AlignedAllocator<int8_t, XNN_ALLOCATION_ALIGNMENT>> packed_w(packed_n() * packed_k() + packed_n() * (sizeof(int32_t) + sizeof(float)) / sizeof(int8_t));
std::vector<int16_t, AlignedAllocator<int16_t, XNN_ALLOCATION_ALIGNMENT>> packed_xw(packed_n() * packed_k() + packed_n() * (sizeof(int32_t) + sizeof(float)) / sizeof(int16_t));
std::vector<int8_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
std::vector<int32_t> acc(m() * n());
std::vector<float> scale(n());
Expand All @@ -636,7 +635,7 @@ void GemmMicrokernelTester::Test(

std::fill(packed_w.begin(), packed_w.end(), 0);
const xnn_qs8_packing_params packing_params = { int8_t(a_zero_point() - 0x80) };
void* const packed_data = extended_weights() ? static_cast<void*>(packed_xw.data()) : packed_w.data();
void* const packed_data = packed_w.data();
pack(/*g=*/1, n(), k(), nr(), kr(), sr(),
b.data(), bias.data(), /*scale=*/nullptr, packed_data, nr() * sizeof(float), &packing_params);

Expand Down Expand Up @@ -666,7 +665,7 @@ void GemmMicrokernelTester::Test(
scale[n_index] = 1.0f / c_scale;
}

const size_t type_size = extended_weights() ? sizeof(int16_t): sizeof(int8_t);
const size_t type_size = sizeof(int8_t);
xnn_init_qs8_qc8w_scale_fp32_params(
n(), nr(), nr(),
nr() * (packed_k() * type_size + (sizeof(int32_t) + sizeof(float))),
Expand Down Expand Up @@ -1865,7 +1864,6 @@ void GemmMicrokernelTester::Test(
std::vector<int8_t> b(n() * k());
std::vector<int32_t> bias(n());
std::vector<int8_t, AlignedAllocator<int8_t, XNN_ALLOCATION_ALIGNMENT>> packed_w(packed_n() * packed_k() + packed_n() * sizeof(int32_t) / sizeof(int8_t));
std::vector<int16_t, AlignedAllocator<int16_t, XNN_ALLOCATION_ALIGNMENT>> packed_xw(packed_n() * packed_k() + packed_n() * sizeof(int32_t) / sizeof(int16_t));
std::vector<int8_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
std::vector<int32_t> acc(m() * n());
std::vector<int8_t> c_ref(m() * n());
Expand All @@ -1878,7 +1876,7 @@ void GemmMicrokernelTester::Test(

std::fill(packed_w.begin(), packed_w.end(), 0);
const xnn_qs8_packing_params packing_params = { int8_t(a_zero_point() - 0x80) };
void* const packed_data = extended_weights() ? static_cast<void*>(packed_xw.data()) : packed_w.data();
void* const packed_data = packed_w.data();
pack(/*g=*/1, n(), k(), nr(), kr(), sr(),
b.data(), bias.data(), /*scale=*/nullptr, packed_data, /*extra_bytes=*/0, &packing_params);

Expand Down Expand Up @@ -4411,7 +4409,6 @@ void GemmMicrokernelTester::Test(
std::vector<int8_t> b(n() * k());
std::vector<int32_t> bias(n());
std::vector<int8_t, AlignedAllocator<int8_t, XNN_ALLOCATION_ALIGNMENT>> packed_w(packed_n() * packed_k() + packed_n() * (sizeof(int32_t) + sizeof(float)) / sizeof(int8_t));
std::vector<int16_t, AlignedAllocator<int16_t, XNN_ALLOCATION_ALIGNMENT>> packed_xw(packed_n() * packed_k() + packed_n() * (sizeof(int32_t) + sizeof(float)) / sizeof(int16_t));
std::vector<int8_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
std::vector<int32_t> acc(m() * n());
std::vector<float> scale(n());
Expand All @@ -4425,7 +4422,7 @@ void GemmMicrokernelTester::Test(

std::fill(packed_w.begin(), packed_w.end(), 0);
const xnn_qs8_packing_params packing_params = { int8_t(a_zero_point() - 0x80) };
void* const packed_data = extended_weights() ? static_cast<void*>(packed_xw.data()) : packed_w.data();
void* const packed_data = packed_w.data();
pack(/*g=*/1, n(), k(), nr(), kr(), sr(),
b.data(), bias.data(), /*scale=*/nullptr, packed_data, nr() * sizeof(float), &packing_params);

Expand Down Expand Up @@ -4455,23 +4452,13 @@ void GemmMicrokernelTester::Test(
scale[n_index] = 1.0f / c_scale;
}

if (extended_weights()) {
xnn_init_qs8_qc8w_scale_fp32_params(
n(), nr(), nr(),
nr() * (packed_k() * sizeof(int16_t) + (sizeof(int32_t) + sizeof(float))),
nr() * (packed_k() * sizeof(int16_t) + (sizeof(int32_t) + sizeof(float))),
0,
scale.data(),
(void*) ((uintptr_t) packed_xw.data() + nr() * (packed_k() * sizeof(int16_t) + sizeof(int32_t))));
} else {
xnn_init_qs8_qc8w_scale_fp32_params(
n(), nr(), nr(),
nr() * (packed_k() * sizeof(int8_t) + (sizeof(int32_t) + sizeof(float))),
nr() * (packed_k() * sizeof(int8_t) + (sizeof(int32_t) + sizeof(float))),
0,
scale.data(),
(void*) ((uintptr_t) packed_w.data() + nr() * (packed_k() * sizeof(int8_t) + sizeof(int32_t))));
}
xnn_init_qs8_qc8w_scale_fp32_params(
n(), nr(), nr(),
nr() * (packed_k() * sizeof(int8_t) + (sizeof(int32_t) + sizeof(float))),
nr() * (packed_k() * sizeof(int8_t) + (sizeof(int32_t) + sizeof(float))),
0,
scale.data(),
(void*) ((uintptr_t) packed_w.data() + nr() * (packed_k() * sizeof(int8_t) + sizeof(int32_t))));

union xnn_qs8_qc8w_conv_minmax_params minmax_params;
init_params(&minmax_params,
Expand All @@ -4487,7 +4474,7 @@ void GemmMicrokernelTester::Test(
gemm(
m(), n(), k(),
a.data(), a_stride() * sizeof(int8_t),
extended_weights() ? static_cast<const void*>(packed_xw.data()) : static_cast<const void*>(packed_w.data()),
static_cast<const void*>(packed_w.data()),
c.data(), cm_stride() * sizeof(int8_t), cn_stride() * sizeof(int8_t),
&minmax_params);

Expand Down Expand Up @@ -4680,7 +4667,6 @@ void GemmMicrokernelTester::Test(
std::vector<int8_t> b(n() * k());
std::vector<int32_t> bias(n());
std::vector<int8_t, AlignedAllocator<int8_t, XNN_ALLOCATION_ALIGNMENT>> packed_w(packed_n() * packed_k() + packed_n() * sizeof(int32_t) / sizeof(int8_t));
std::vector<int16_t, AlignedAllocator<int16_t, XNN_ALLOCATION_ALIGNMENT>> packed_xw(packed_n() * packed_k() + packed_n() * sizeof(int32_t) / sizeof(int16_t));
std::vector<int8_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
std::vector<int32_t> acc(m() * n());
std::vector<int8_t> c_ref(m() * n());
Expand All @@ -4693,7 +4679,7 @@ void GemmMicrokernelTester::Test(

std::fill(packed_w.begin(), packed_w.end(), 0);
const xnn_qs8_packing_params packing_params = { int8_t(a_zero_point() - 0x80) };
void* const packed_data = extended_weights() ? static_cast<void*>(packed_xw.data()) : packed_w.data();
void* const packed_data = packed_w.data();
pack(/*g=*/1, n(), k(), nr(), kr(), sr(),
b.data(), bias.data(), /*scale=*/nullptr, packed_data, /*extra_bytes=*/0, &packing_params);

Expand Down Expand Up @@ -4732,7 +4718,7 @@ void GemmMicrokernelTester::Test(
gemm(
m(), n(), k(),
a.data(), a_stride() * sizeof(int8_t),
extended_weights() ? static_cast<const void*>(packed_xw.data()) : static_cast<const void*>(packed_w.data()),
static_cast<const void*>(packed_w.data()),
c.data(), cm_stride() * sizeof(int8_t), cn_stride() * sizeof(int8_t),
&quantization_params);

Expand Down
10 changes: 0 additions & 10 deletions test/gemm-microkernel-tester.h
Original file line number Diff line number Diff line change
Expand Up @@ -207,15 +207,6 @@ class GemmMicrokernelTester {
return this->zero_index_;
}

GemmMicrokernelTester& extended_weights(bool extended_weights) {
this->extended_weights_ = extended_weights;
return *this;
}

bool extended_weights() const {
return this->extended_weights_;
}

GemmMicrokernelTester& iterations(size_t iterations) {
this->iterations_ = iterations;
return *this;
Expand Down Expand Up @@ -503,7 +494,6 @@ class GemmMicrokernelTester {
uint8_t qmax_{255};
size_t a_offset_{0};
size_t zero_index_{SIZE_MAX};
bool extended_weights_{false};
size_t iterations_{15};
bool known_nc_mod_nr_{true};
bool relu_{false};
Expand Down
Loading

0 comments on commit f0aa7a9

Please sign in to comment.