From f0aa7a9883e2c81fb42125b6fa539fada75ffc4d Mon Sep 17 00:00:00 2001 From: Frank Barchard Date: Tue, 30 Jul 2024 18:01:30 -0700 Subject: [PATCH] Remove extended weights support PiperOrigin-RevId: 657795516 --- bench/gemm-benchmark.cc | 35 ++------- bench/gemm-benchmark.h | 26 +------ src/packing.c | 64 ---------------- src/xnnpack/pack.h | 14 ---- test/gemm-microkernel-tester.cc | 42 ++++------- test/gemm-microkernel-tester.h | 10 --- tools/generate-gemm-test.py | 125 +------------------------------- 7 files changed, 24 insertions(+), 292 deletions(-) diff --git a/bench/gemm-benchmark.cc b/bench/gemm-benchmark.cc index 885c99cba639..90b0921aa1f6 100644 --- a/bench/gemm-benchmark.cc +++ b/bench/gemm-benchmark.cc @@ -32,8 +32,7 @@ void GEMMBenchmark(benchmark::State& state, xnn_qs8_gemm_minmax_ukernel_fn gemm, xnn_init_qs8_conv_minmax_params_fn init_params, xnn_pack_qs8_gemm_fn pack, size_t mr, size_t nr, size_t kr, - size_t sr, benchmark::utils::IsaCheckFunction isa_check, - bool extended_weights) { + size_t sr, benchmark::utils::IsaCheckFunction isa_check) { if (isa_check != nullptr && !isa_check(state)) { return; } @@ -61,10 +60,8 @@ void GEMMBenchmark(benchmark::State& state, xnn_qs8_gemm_minmax_ukernel_fn gemm, std::vector b(nc); std::generate(b.begin(), b.end(), std::ref(i32rng)); - const size_t w_element_size = - extended_weights ? sizeof(int16_t) : sizeof(int8_t); - const size_t w_size = - nc_stride * sizeof(int32_t) + kc_stride * nc_stride * w_element_size; + const size_t w_element_size = sizeof(int8_t); + const size_t w_size = nc_stride * sizeof(int32_t) + kc_stride * nc_stride * w_element_size; const size_t c_elements = mc * nc; const size_t num_buffers = 1 + benchmark::utils::DivideRoundUp( benchmark::utils::GetMaxCacheSize(), @@ -127,8 +124,7 @@ void GEMMBenchmark(benchmark::State& state, xnn_qs8_qc8w_gemm_minmax_ukernel_fn gemm, xnn_init_qs8_qc8w_conv_minmax_params_fn init_params, xnn_pack_qs8_gemm_fn pack, size_t mr, size_t nr, size_t kr, - size_t sr, benchmark::utils::IsaCheckFunction isa_check, - bool extended_weights) { + size_t sr, benchmark::utils::IsaCheckFunction isa_check) { if (isa_check != nullptr && !isa_check(state)) { return; } @@ -156,10 +152,8 @@ void GEMMBenchmark(benchmark::State& state, std::vector b(nc); std::generate(b.begin(), b.end(), std::ref(i32rng)); - const size_t w_element_size = - extended_weights ? sizeof(int16_t) : sizeof(int8_t); - const size_t w_size = - nc_stride * sizeof(int32_t) + kc_stride * nc_stride * w_element_size; + const size_t w_element_size = sizeof(int8_t); + const size_t w_size = nc_stride * sizeof(int32_t) + kc_stride * nc_stride * w_element_size; const size_t c_elements = mc * nc; const size_t num_buffers = 1 + benchmark::utils::DivideRoundUp( benchmark::utils::GetMaxCacheSize(), @@ -216,23 +210,6 @@ void GEMMBenchmark(benchmark::State& state, benchmark::Counter::kIsRate); } -void GEMMBenchmark(benchmark::State& state, - xnn_qs8_qc8w_gemm_minmax_ukernel_fn gemm, - xnn_init_qs8_qc8w_conv_minmax_params_fn init_params, - xnn_pack_qs8_gemm_fn pack, size_t mr, size_t nr, size_t kr, - size_t sr, benchmark::utils::IsaCheckFunction isa_check) { - return GEMMBenchmark(state, gemm, init_params, pack, mr, nr, kr, sr, - isa_check, /*extended_weights=*/false); -} - -void GEMMBenchmark(benchmark::State& state, xnn_qs8_gemm_minmax_ukernel_fn gemm, - xnn_init_qs8_conv_minmax_params_fn init_params, - xnn_pack_qs8_gemm_fn pack, size_t mr, size_t nr, size_t kr, - size_t sr, benchmark::utils::IsaCheckFunction isa_check) { - return GEMMBenchmark(state, gemm, init_params, pack, mr, nr, kr, sr, - isa_check, /*extended_weights=*/false); -} - void GEMMBenchmark(benchmark::State& state, xnn_qd8_f16_qc8w_gemm_ukernel_fn gemm, xnn_init_f16_minmax_params_fn init_params, diff --git a/bench/gemm-benchmark.h b/bench/gemm-benchmark.h index 263fb9f26595..4ff5b91eb69b 100644 --- a/bench/gemm-benchmark.h +++ b/bench/gemm-benchmark.h @@ -23,35 +23,13 @@ void GEMMBenchmark(benchmark::State& state, xnn_qs8_gemm_minmax_ukernel_fn gemm, xnn_init_qs8_conv_minmax_params_fn init_params, xnn_pack_qs8_gemm_fn pack, size_t mr, size_t nr, size_t kr, - size_t sr, benchmark::utils::IsaCheckFunction isa_check, - bool extended_weights); + size_t sr, benchmark::utils::IsaCheckFunction isa_check); void GEMMBenchmark(benchmark::State& state, xnn_qs8_qc8w_gemm_minmax_ukernel_fn gemm, xnn_init_qs8_qc8w_conv_minmax_params_fn init_params, xnn_pack_qs8_gemm_fn pack, size_t mr, size_t nr, size_t kr, - size_t sr, benchmark::utils::IsaCheckFunction isa_check, - bool extended_weights); - -static void GEMMBenchmark(benchmark::State& state, - xnn_qs8_qc8w_gemm_minmax_ukernel_fn gemm, - xnn_init_qs8_qc8w_conv_minmax_params_fn init_params, - xnn_pack_qs8_gemm_fn pack, size_t mr, size_t nr, - size_t kr, size_t sr, - benchmark::utils::IsaCheckFunction isa_check) { - return GEMMBenchmark(state, gemm, init_params, pack, mr, nr, kr, sr, - isa_check, /*extended_weights=*/false); -} - -static void GEMMBenchmark(benchmark::State& state, - xnn_qs8_gemm_minmax_ukernel_fn gemm, - xnn_init_qs8_conv_minmax_params_fn init_params, - xnn_pack_qs8_gemm_fn pack, size_t mr, size_t nr, - size_t kr, size_t sr, - benchmark::utils::IsaCheckFunction isa_check) { - return GEMMBenchmark(state, gemm, init_params, pack, mr, nr, kr, sr, - isa_check, /*extended_weights=*/false); -} + size_t sr, benchmark::utils::IsaCheckFunction isa_check); void GEMMBenchmark(benchmark::State& state, xnn_qd8_f16_qc8w_gemm_ukernel_fn gemm, diff --git a/src/packing.c b/src/packing.c index 2829acbf0588..0de27422d871 100644 --- a/src/packing.c +++ b/src/packing.c @@ -887,70 +887,6 @@ void xnn_pack_f32_qc4w_gemm_goi_w( } while (--g != 0); } -void xnn_pack_qs8_gemm_xw_goi_w( - size_t g, - size_t nc, - size_t kc, - size_t nr, - size_t kr, - size_t sr, - const int8_t* k, - const int32_t* b, - const float* scale, - void* packed_weights, - size_t extra_bytes, - const struct xnn_qs8_packing_params* params) -{ - assert(g != 0); - assert(nr >= sr); - assert(k != NULL); - assert(packed_weights != NULL); - - const size_t skr = sr * kr; - const uint32_t izp = (uint32_t) params->input_zero_point; - do { - for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) { - const size_t nr_block_size = min(nc - nr_block_start, nr); - int32_t* packed_b = (int32_t*) packed_weights; - if XNN_LIKELY(b != NULL) { - for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) { - unaligned_store_s32(packed_weights, b[nr_block_start + nr_block_offset]); - packed_weights = (void*) ((uintptr_t) packed_weights + sizeof(int32_t)); - } - } else { - size_t n = nr_block_size; - do { - unaligned_store_s32(packed_weights, 0); - packed_weights = (void*) ((uintptr_t) packed_weights + sizeof(int32_t)); - } while (--n != 0); - } - packed_weights = (void*) ((uintptr_t) packed_weights + (nr - nr_block_size) * sizeof(int32_t)); - - for (size_t kr_block_start = 0; kr_block_start < round_up_po2(kc, skr); kr_block_start += kr) { - for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) { - uint32_t ksum = 0; - for (size_t kr_block_offset = 0; kr_block_offset < kr; kr_block_offset++) { - const size_t kc_idx = round_down_po2(kr_block_start, skr) + ((kr_block_start + kr_block_offset + nr_block_offset * kr) & (skr - 1)); - if (kc_idx < kc) { - const int8_t kv = k[(nr_block_start + nr_block_offset) * kc + kc_idx]; - ksum += (uint32_t) kv; - ((int16_t*) packed_weights)[kr_block_offset] = (int16_t) kv; - } - } - unaligned_indexed_store_u32(packed_b, nr_block_offset, unaligned_indexed_load_u32(packed_b, nr_block_offset) - ksum * izp); - packed_weights = (int16_t*) packed_weights + kr; - } - packed_weights = (int16_t*) packed_weights + (nr - nr_block_size) * kr; - } - packed_weights = (void*) ((uintptr_t) packed_weights + extra_bytes); - } - k += nc * kc; - if XNN_UNPREDICTABLE(b != NULL) { - b += nc; - } - } while (--g != 0); -} - void xnn_pack_f32_gemm_gio_w( size_t g, size_t nc, diff --git a/src/xnnpack/pack.h b/src/xnnpack/pack.h index 0c0619893c37..a655e3b9f0d0 100644 --- a/src/xnnpack/pack.h +++ b/src/xnnpack/pack.h @@ -295,20 +295,6 @@ XNN_INTERNAL void xnn_pack_f32_qs8w_gemm_goi_w( size_t extra_bytes, const void* params); -XNN_INTERNAL void xnn_pack_qs8_gemm_xw_goi_w( - size_t g, - size_t nc, - size_t kc, - size_t nr, - size_t kr, - size_t sr, - const int8_t* kernel, - const int32_t* bias, - const float* scale, - void* packed_weights, - size_t extra_bytes, - const struct xnn_qs8_packing_params* params); - XNN_INTERNAL void xnn_pack_f32_gemm_gio_w( size_t g, size_t nc, diff --git a/test/gemm-microkernel-tester.cc b/test/gemm-microkernel-tester.cc index e263b1b0172f..a94dfa506385 100644 --- a/test/gemm-microkernel-tester.cc +++ b/test/gemm-microkernel-tester.cc @@ -622,7 +622,6 @@ void GemmMicrokernelTester::Test( std::vector b(n() * k()); std::vector bias(n()); std::vector> packed_w(packed_n() * packed_k() + packed_n() * (sizeof(int32_t) + sizeof(float)) / sizeof(int8_t)); - std::vector> packed_xw(packed_n() * packed_k() + packed_n() * (sizeof(int32_t) + sizeof(float)) / sizeof(int16_t)); std::vector c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1); std::vector acc(m() * n()); std::vector scale(n()); @@ -636,7 +635,7 @@ void GemmMicrokernelTester::Test( std::fill(packed_w.begin(), packed_w.end(), 0); const xnn_qs8_packing_params packing_params = { int8_t(a_zero_point() - 0x80) }; - void* const packed_data = extended_weights() ? static_cast(packed_xw.data()) : packed_w.data(); + void* const packed_data = packed_w.data(); pack(/*g=*/1, n(), k(), nr(), kr(), sr(), b.data(), bias.data(), /*scale=*/nullptr, packed_data, nr() * sizeof(float), &packing_params); @@ -666,7 +665,7 @@ void GemmMicrokernelTester::Test( scale[n_index] = 1.0f / c_scale; } - const size_t type_size = extended_weights() ? sizeof(int16_t): sizeof(int8_t); + const size_t type_size = sizeof(int8_t); xnn_init_qs8_qc8w_scale_fp32_params( n(), nr(), nr(), nr() * (packed_k() * type_size + (sizeof(int32_t) + sizeof(float))), @@ -1865,7 +1864,6 @@ void GemmMicrokernelTester::Test( std::vector b(n() * k()); std::vector bias(n()); std::vector> packed_w(packed_n() * packed_k() + packed_n() * sizeof(int32_t) / sizeof(int8_t)); - std::vector> packed_xw(packed_n() * packed_k() + packed_n() * sizeof(int32_t) / sizeof(int16_t)); std::vector c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1); std::vector acc(m() * n()); std::vector c_ref(m() * n()); @@ -1878,7 +1876,7 @@ void GemmMicrokernelTester::Test( std::fill(packed_w.begin(), packed_w.end(), 0); const xnn_qs8_packing_params packing_params = { int8_t(a_zero_point() - 0x80) }; - void* const packed_data = extended_weights() ? static_cast(packed_xw.data()) : packed_w.data(); + void* const packed_data = packed_w.data(); pack(/*g=*/1, n(), k(), nr(), kr(), sr(), b.data(), bias.data(), /*scale=*/nullptr, packed_data, /*extra_bytes=*/0, &packing_params); @@ -4411,7 +4409,6 @@ void GemmMicrokernelTester::Test( std::vector b(n() * k()); std::vector bias(n()); std::vector> packed_w(packed_n() * packed_k() + packed_n() * (sizeof(int32_t) + sizeof(float)) / sizeof(int8_t)); - std::vector> packed_xw(packed_n() * packed_k() + packed_n() * (sizeof(int32_t) + sizeof(float)) / sizeof(int16_t)); std::vector c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1); std::vector acc(m() * n()); std::vector scale(n()); @@ -4425,7 +4422,7 @@ void GemmMicrokernelTester::Test( std::fill(packed_w.begin(), packed_w.end(), 0); const xnn_qs8_packing_params packing_params = { int8_t(a_zero_point() - 0x80) }; - void* const packed_data = extended_weights() ? static_cast(packed_xw.data()) : packed_w.data(); + void* const packed_data = packed_w.data(); pack(/*g=*/1, n(), k(), nr(), kr(), sr(), b.data(), bias.data(), /*scale=*/nullptr, packed_data, nr() * sizeof(float), &packing_params); @@ -4455,23 +4452,13 @@ void GemmMicrokernelTester::Test( scale[n_index] = 1.0f / c_scale; } - if (extended_weights()) { - xnn_init_qs8_qc8w_scale_fp32_params( - n(), nr(), nr(), - nr() * (packed_k() * sizeof(int16_t) + (sizeof(int32_t) + sizeof(float))), - nr() * (packed_k() * sizeof(int16_t) + (sizeof(int32_t) + sizeof(float))), - 0, - scale.data(), - (void*) ((uintptr_t) packed_xw.data() + nr() * (packed_k() * sizeof(int16_t) + sizeof(int32_t)))); - } else { - xnn_init_qs8_qc8w_scale_fp32_params( - n(), nr(), nr(), - nr() * (packed_k() * sizeof(int8_t) + (sizeof(int32_t) + sizeof(float))), - nr() * (packed_k() * sizeof(int8_t) + (sizeof(int32_t) + sizeof(float))), - 0, - scale.data(), - (void*) ((uintptr_t) packed_w.data() + nr() * (packed_k() * sizeof(int8_t) + sizeof(int32_t)))); - } + xnn_init_qs8_qc8w_scale_fp32_params( + n(), nr(), nr(), + nr() * (packed_k() * sizeof(int8_t) + (sizeof(int32_t) + sizeof(float))), + nr() * (packed_k() * sizeof(int8_t) + (sizeof(int32_t) + sizeof(float))), + 0, + scale.data(), + (void*) ((uintptr_t) packed_w.data() + nr() * (packed_k() * sizeof(int8_t) + sizeof(int32_t)))); union xnn_qs8_qc8w_conv_minmax_params minmax_params; init_params(&minmax_params, @@ -4487,7 +4474,7 @@ void GemmMicrokernelTester::Test( gemm( m(), n(), k(), a.data(), a_stride() * sizeof(int8_t), - extended_weights() ? static_cast(packed_xw.data()) : static_cast(packed_w.data()), + static_cast(packed_w.data()), c.data(), cm_stride() * sizeof(int8_t), cn_stride() * sizeof(int8_t), &minmax_params); @@ -4680,7 +4667,6 @@ void GemmMicrokernelTester::Test( std::vector b(n() * k()); std::vector bias(n()); std::vector> packed_w(packed_n() * packed_k() + packed_n() * sizeof(int32_t) / sizeof(int8_t)); - std::vector> packed_xw(packed_n() * packed_k() + packed_n() * sizeof(int32_t) / sizeof(int16_t)); std::vector c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1); std::vector acc(m() * n()); std::vector c_ref(m() * n()); @@ -4693,7 +4679,7 @@ void GemmMicrokernelTester::Test( std::fill(packed_w.begin(), packed_w.end(), 0); const xnn_qs8_packing_params packing_params = { int8_t(a_zero_point() - 0x80) }; - void* const packed_data = extended_weights() ? static_cast(packed_xw.data()) : packed_w.data(); + void* const packed_data = packed_w.data(); pack(/*g=*/1, n(), k(), nr(), kr(), sr(), b.data(), bias.data(), /*scale=*/nullptr, packed_data, /*extra_bytes=*/0, &packing_params); @@ -4732,7 +4718,7 @@ void GemmMicrokernelTester::Test( gemm( m(), n(), k(), a.data(), a_stride() * sizeof(int8_t), - extended_weights() ? static_cast(packed_xw.data()) : static_cast(packed_w.data()), + static_cast(packed_w.data()), c.data(), cm_stride() * sizeof(int8_t), cn_stride() * sizeof(int8_t), &quantization_params); diff --git a/test/gemm-microkernel-tester.h b/test/gemm-microkernel-tester.h index b9a0d0f266f5..3a06ce027b42 100644 --- a/test/gemm-microkernel-tester.h +++ b/test/gemm-microkernel-tester.h @@ -207,15 +207,6 @@ class GemmMicrokernelTester { return this->zero_index_; } - GemmMicrokernelTester& extended_weights(bool extended_weights) { - this->extended_weights_ = extended_weights; - return *this; - } - - bool extended_weights() const { - return this->extended_weights_; - } - GemmMicrokernelTester& iterations(size_t iterations) { this->iterations_ = iterations; return *this; @@ -503,7 +494,6 @@ class GemmMicrokernelTester { uint8_t qmax_{255}; size_t a_offset_{0}; size_t zero_index_{SIZE_MAX}; - bool extended_weights_{false}; size_t iterations_{15}; bool known_nc_mod_nr_{true}; bool relu_{false}; diff --git a/tools/generate-gemm-test.py b/tools/generate-gemm-test.py index e5329421d972..f127d77767cd 100755 --- a/tools/generate-gemm-test.py +++ b/tools/generate-gemm-test.py @@ -39,7 +39,6 @@ def split_ukernel_name(name): common_name, target_name = name.split("__", 1) common_parts = common_name.split("_") - xw = "gemm_xw_" in common_name param_spec = common_parts[-1] if "s" in param_spec: param_spec, sr = param_spec.split("s", 1) @@ -67,29 +66,7 @@ def split_ukernel_name(name): requantization = common_parts[-3] if requantization not in ["fp32", "rndnu"]: requantization = None - return mr, nr, kr, sr, mr_packed, xw, vector_tile, requantization, arch, isa, assembly - - -GEMM_BENCH_CODE_XW = """\ -static void ${UKERNEL_NAME}(benchmark::State& state, const char* net) { - GEMMBenchmark(state, - ${GEMM}, - $if INIT_PARAMS is not None: - ${INIT_PARAMS}, - $if PACK_FN is not None: - ${PACK_FN}, - /*mr=*/${MR}, /*nr=*/${NR}${NR_SCALE}, /*kr=*/${KR}, /*sr=*/${SR}, - $if ISA_CHECK: - benchmark::utils::${ISA_CHECK}, - $else: - /*isa_check=*/nullptr, - /*extended_weights=*/true); -}\n -$if KERNELTYPE in ['qb4w']: - BENCHMARK_GEMM_BL(${UKERNEL_NAME}) -$else: - BENCHMARK_GEMM(${UKERNEL_NAME}) -""" + return mr, nr, kr, sr, mr_packed, vector_tile, requantization, arch, isa, assembly GEMM_BENCH_CODE = """\ $if CPP_CHECK: @@ -148,8 +125,6 @@ def split_ukernel_name(name): gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs, tester.clone() - $if EXTENDED_WEIGHTS: - .extended_weights(true) .m(mr).n(nr).k(k_block) $if KERNELTYPE in ['qb4w', 'qc4w']: .b_zero_point(8) @@ -160,8 +135,6 @@ def split_ukernel_name(name): gemm_tests.push_back(GemmTestParams( "strided_cn", tester.clone() - $if EXTENDED_WEIGHTS: - .extended_weights(true) .m(mr).n(nr).k(k_block) .cn_stride(xnnpack::NextPrime(nr + 1)) $if KERNELTYPE in ['qb4w', 'qc4w']: @@ -173,8 +146,6 @@ def split_ukernel_name(name): gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", tester.clone() - $if EXTENDED_WEIGHTS: - .extended_weights(true) .m(mr).n(nr).k(k_block) .a_stride(xnnpack::NextPrime(k_block + 1)) $if KERNELTYPE in ['qb4w', 'qc4w']: @@ -186,8 +157,6 @@ def split_ukernel_name(name): gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_subtile", tester.clone() - $if EXTENDED_WEIGHTS: - .extended_weights(true) .k(k_block).iterations(1) $if KERNELTYPE in ['qb4w', 'qc4w']: .b_zero_point(8) @@ -199,8 +168,6 @@ def split_ukernel_name(name): gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_subtile_m", tester.clone() - $if EXTENDED_WEIGHTS: - .extended_weights(true) .n(nr).k(k_block).iterations(1) $if KERNELTYPE in ['qb4w', 'qc4w']: .b_zero_point(8) @@ -211,8 +178,6 @@ def split_ukernel_name(name): gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_subtile_n", tester.clone() - $if EXTENDED_WEIGHTS: - .extended_weights(true) .m(mr).k(k_block).iterations(1) $if KERNELTYPE in ['qb4w', 'qc4w']: .b_zero_point(8) @@ -224,8 +189,6 @@ def split_ukernel_name(name): gemm_tests.push_back(GemmTestParams( "k_eq_" + kb2s, tester.clone() - $if EXTENDED_WEIGHTS: - .extended_weights(true) .m(mr).n(nr).k(k_block * 2) $if KERNELTYPE in ['qb4w', 'qc4w']: .b_zero_point(8) @@ -236,8 +199,6 @@ def split_ukernel_name(name): gemm_tests.push_back(GemmTestParams( "k_eq_" + kb2s + "_strided_a", tester.clone() - $if EXTENDED_WEIGHTS: - .extended_weights(true) .m(mr).n(nr).k(k_block * 2) .a_stride(xnnpack::NextPrime(k_block * 2 + 1)) $if KERNELTYPE in ['qb4w', 'qc4w']: @@ -249,8 +210,6 @@ def split_ukernel_name(name): gemm_tests.push_back(GemmTestParams( "k_eq_" + kb2s + "_subtile", tester.clone() - $if EXTENDED_WEIGHTS: - .extended_weights(true) .k(k_block * 2).iterations(1) $if KERNELTYPE in ['qb4w', 'qc4w']: .b_zero_point(8) @@ -264,8 +223,6 @@ def split_ukernel_name(name): gemm_tests.push_back(GemmTestParams( "k_lt_" + akbs, tester.clone() - $if EXTENDED_WEIGHTS: - .extended_weights(true) .m(mr).n(nr) $if KERNELTYPE in ['qb4w', 'qc4w']: .b_zero_point(8) @@ -277,8 +234,6 @@ def split_ukernel_name(name): gemm_tests.push_back(GemmTestParams( "k_lt_" + akbs + "_strided_a", tester.clone() - $if EXTENDED_WEIGHTS: - .extended_weights(true) .m(mr).n(nr) .a_stride(xnnpack::NextPrime(adj_k_block + 1)) $if KERNELTYPE in ['qb4w', 'qc4w']: @@ -291,8 +246,6 @@ def split_ukernel_name(name): gemm_tests.push_back(GemmTestParams( "k_lt_" + akbs + "_subtile", tester.clone() - $if EXTENDED_WEIGHTS: - .extended_weights(true) .iterations(1) $if KERNELTYPE in ['qb4w', 'qc4w']: .b_zero_point(8) @@ -306,8 +259,6 @@ def split_ukernel_name(name): gemm_tests.push_back(GemmTestParams( "k_gt_" + akbs, tester.clone() - $if EXTENDED_WEIGHTS: - .extended_weights(true) .m(mr).n(nr) $if KERNELTYPE in ['qb4w', 'qc4w']: .b_zero_point(8) @@ -319,8 +270,6 @@ def split_ukernel_name(name): gemm_tests.push_back(GemmTestParams( "k_gt_" + akbs + "_strided_a", tester.clone() - $if EXTENDED_WEIGHTS: - .extended_weights(true) .m(mr).n(nr) .a_stride(xnnpack::NextPrime(adj_k_block * 2 + 1)) $if KERNELTYPE in ['qb4w', 'qc4w']: @@ -333,8 +282,6 @@ def split_ukernel_name(name): gemm_tests.push_back(GemmTestParams( "k_gt_" + akbs + "_subtile", tester.clone() - $if EXTENDED_WEIGHTS: - .extended_weights(true) .iterations(1) $if KERNELTYPE in ['qb4w', 'qc4w']: .b_zero_point(8) @@ -348,8 +295,6 @@ def split_ukernel_name(name): gemm_tests.push_back(GemmTestParams( "k_div_" + kbs, tester.clone() - $if EXTENDED_WEIGHTS: - .extended_weights(true) .m(mr).n(nr) $if KERNELTYPE in ['qb4w', 'qc4w']: .b_zero_point(8) @@ -361,8 +306,6 @@ def split_ukernel_name(name): gemm_tests.push_back(GemmTestParams( "k_div_" + kbs + "_strided_a", tester.clone() - $if EXTENDED_WEIGHTS: - .extended_weights(true) .m(mr).n(nr) .a_stride(xnnpack::NextPrime(k_block * 3 + 1)) $if KERNELTYPE in ['qb4w', 'qc4w']: @@ -375,8 +318,6 @@ def split_ukernel_name(name): gemm_tests.push_back(GemmTestParams( "k_div_" + kbs + "_subtile", tester.clone() - $if EXTENDED_WEIGHTS: - .extended_weights(true) .iterations(1) $if KERNELTYPE in ['qb4w', 'qc4w']: .b_zero_point(8) @@ -390,8 +331,6 @@ def split_ukernel_name(name): gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs, tester.clone() - $if EXTENDED_WEIGHTS: - .extended_weights(true) .m(mr) $if KERNELTYPE in ['qb4w', 'qc4w']: .b_zero_point(8) @@ -407,8 +346,6 @@ def split_ukernel_name(name): gemm_tests.push_back(GemmTestParams( "unknown_nc_mod_nr", tester.clone() - $if EXTENDED_WEIGHTS: - .extended_weights(true) .m(mr).known_nc_mod_nr(false) $if KERNELTYPE in ['qb4w', 'qc4w']: .b_zero_point(8) @@ -423,8 +360,6 @@ def split_ukernel_name(name): gemm_tests.push_back(GemmTestParams( "relu", tester.clone() - $if EXTENDED_WEIGHTS: - .extended_weights(true) .m(mr).n(nr).k(k_block).relu(true) $if KERNELTYPE in ['qb4w', 'qc4w']: .b_zero_point(8) @@ -435,8 +370,6 @@ def split_ukernel_name(name): gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_cn", tester.clone() - $if EXTENDED_WEIGHTS: - .extended_weights(true) .m(mr) .cn_stride(xnnpack::NextPrime(nr + 1)) $if KERNELTYPE in ['qb4w', 'qc4w']: @@ -455,8 +388,6 @@ def split_ukernel_name(name): gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", tester.clone() - $if EXTENDED_WEIGHTS: - .extended_weights(true) .m(mr) .a_stride(xnnpack::NextPrime(k_block * 3 + 1)) $if KERNELTYPE in ['qb4w', 'qc4w']: @@ -473,8 +404,6 @@ def split_ukernel_name(name): gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_subtile", tester.clone() - $if EXTENDED_WEIGHTS: - .extended_weights(true) .iterations(1) $if KERNELTYPE in ['qb4w', 'qc4w']: .b_zero_point(8) @@ -490,8 +419,6 @@ def split_ukernel_name(name): gemm_tests.push_back(GemmTestParams( "n_div_" + nrs, tester.clone() - $if EXTENDED_WEIGHTS: - .extended_weights(true) .m(mr) $if KERNELTYPE in ['qb4w', 'qc4w']: .b_zero_point(8) @@ -504,8 +431,6 @@ def split_ukernel_name(name): gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_cn", tester.clone() - $if EXTENDED_WEIGHTS: - .extended_weights(true) .m(mr) .cn_stride(xnnpack::NextPrime(nr + 1)) $if KERNELTYPE in ['qb4w', 'qc4w']: @@ -519,8 +444,6 @@ def split_ukernel_name(name): gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", tester.clone() - $if EXTENDED_WEIGHTS: - .extended_weights(true) .m(mr) .a_stride(xnnpack::NextPrime(k_block * 3 + 1)) $if KERNELTYPE in ['qb4w', 'qc4w']: @@ -534,8 +457,6 @@ def split_ukernel_name(name): gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_subtile", tester.clone() - $if EXTENDED_WEIGHTS: - .extended_weights(true) .iterations(1) $if KERNELTYPE in ['qb4w', 'qc4w']: .b_zero_point(8) @@ -549,8 +470,6 @@ def split_ukernel_name(name): gemm_tests.push_back(GemmTestParams( "small_kernel", tester.clone() - $if EXTENDED_WEIGHTS: - .extended_weights(true) .m(mr).n(nr).ks(3) $if KERNELTYPE in ['qb4w', 'qc4w']: .b_zero_point(8) @@ -561,8 +480,6 @@ def split_ukernel_name(name): gemm_tests.push_back(GemmTestParams( "small_kernel_subtile", tester.clone() - $if EXTENDED_WEIGHTS: - .extended_weights(true) .ks(3).iterations(1) $if KERNELTYPE in ['qb4w', 'qc4w']: .b_zero_point(8) @@ -575,8 +492,6 @@ def split_ukernel_name(name): gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_small_kernel", tester.clone() - $if EXTENDED_WEIGHTS: - .extended_weights(true) .m(mr).ks(3) $if KERNELTYPE in ['qb4w', 'qc4w']: .b_zero_point(8) @@ -591,8 +506,6 @@ def split_ukernel_name(name): gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_small_kernel", tester.clone() - $if EXTENDED_WEIGHTS: - .extended_weights(true) .m(mr).ks(3) $if KERNELTYPE in ['qb4w', 'qc4w']: .b_zero_point(8) @@ -605,8 +518,6 @@ def split_ukernel_name(name): gemm_tests.push_back(GemmTestParams( "strided_cm_subtile", tester.clone() - $if EXTENDED_WEIGHTS: - .extended_weights(true) .mr(mr).nr(nr).kr(kr).sr(sr) .cm_stride(xnnpack::NextPrime(nr + 1)) .iterations(1) @@ -622,8 +533,6 @@ def split_ukernel_name(name): gemm_tests.push_back(GemmTestParams( "a_offset", tester.clone() - $if EXTENDED_WEIGHTS: - .extended_weights(true) .m(mr).n(nr).ks(3) .a_offset(xnnpack::NextPrime(mr * k_block * 3 + 1)) $if KERNELTYPE in ['qb4w', 'qc4w']: @@ -635,8 +544,6 @@ def split_ukernel_name(name): gemm_tests.push_back(GemmTestParams( "zero", tester.clone() - $if EXTENDED_WEIGHTS: - .extended_weights(true) .m(mr).n(nr).ks(3) .a_offset(xnnpack::NextPrime(mr * k_block * 3 + 1)) $if KERNELTYPE in ['qb4w', 'qc4w']: @@ -651,8 +558,6 @@ def split_ukernel_name(name): gemm_tests.push_back(GemmTestParams( "qmin", tester.clone() - $if EXTENDED_WEIGHTS: - .extended_weights(true) .m(mr).n(nr).k(k_block).qmin(128) $if KERNELTYPE in ['qb4w', 'qc4w']: .b_zero_point(8) @@ -662,8 +567,6 @@ def split_ukernel_name(name): gemm_tests.push_back(GemmTestParams( "qmax", tester.clone() - $if EXTENDED_WEIGHTS: - .extended_weights(true) .m(mr).n(nr).k(k_block).qmax(128) $if KERNELTYPE in ['qb4w', 'qc4w']: .b_zero_point(8) @@ -673,8 +576,6 @@ def split_ukernel_name(name): gemm_tests.push_back(GemmTestParams( "strided_cm", tester.clone() - $if EXTENDED_WEIGHTS: - .extended_weights(true) .m(mr).n(nr).k(k_block) .cm_stride(xnnpack::NextPrime(nr + 1)) $if KERNELTYPE in ['qb4w', 'qc4w']: @@ -686,8 +587,6 @@ def split_ukernel_name(name): gemm_tests.push_back(GemmTestParams( "no_a_zero_point", tester.clone() - $if EXTENDED_WEIGHTS: - .extended_weights(true) .m(mr).n(nr).a_zero_point(0) , test_func, isa_check) .loop_k(1, k_block * 3, k_block + 1)); @@ -695,24 +594,18 @@ def split_ukernel_name(name): gemm_tests.push_back(GemmTestParams( "no_b_zero_point", tester.clone() - $if EXTENDED_WEIGHTS: - .extended_weights(true) .m(mr).n(nr).b_zero_point(0) , test_func, isa_check) .loop_k(1, k_block * 3, k_block + 1)); gemm_tests.push_back(GemmTestParams( "b_zero_point", tester.clone() - $if EXTENDED_WEIGHTS: - .extended_weights(true) .m(mr).n(nr).k(k_block) , test_func, isa_check) .loop_bzp(0, 255)); gemm_tests.push_back(GemmTestParams( "no_zero_point", tester.clone() - $if EXTENDED_WEIGHTS: - .extended_weights(true) .m(mr).n(nr) .a_zero_point(0) .b_zero_point(0) @@ -722,8 +615,6 @@ def split_ukernel_name(name): gemm_tests.push_back(GemmTestParams( "bl", tester.clone() - $if EXTENDED_WEIGHTS: - .extended_weights(true) .m(mr).n(nr).k(k_block * 12) .b_zero_point(8) , test_func, isa_check) @@ -767,8 +658,6 @@ def split_ukernel_name(name): for (uint32_t m = 1; m <= max_mr; m++) { for (size_t k = 1; k <= ${KBLOCK * 2}; k += 1) { GemmMicrokernelTester() - $if EXTENDED_WEIGHTS: - .extended_weights(true) .mr(max_mr) $if NR > 1: .nr(${NR}) @@ -795,8 +684,6 @@ def split_ukernel_name(name): ${ISA_CHECK}; const std::vector fused_operators = { {xnn_post_operation_type_hardswish} }; GemmMicrokernelTester() - $if EXTENDED_WEIGHTS: - .extended_weights(true) $if MR > 1: .mr(${MR}) $if NR > 1: @@ -822,8 +709,6 @@ def split_ukernel_name(name): const std::vector fused_operators = { {xnn_post_operation_type_hardswish} }; for (uint32_t max_mr = 1; max_mr < ${MR}; max_mr++) { GemmMicrokernelTester() - $if EXTENDED_WEIGHTS: - .extended_weights(true) .mr(max_mr) $if NR > 1: .nr(${NR}) @@ -877,7 +762,6 @@ def generate_test_cases( kr, sr, mr_packed, - xw, k_block, vector_tile, init_fn, @@ -900,7 +784,6 @@ def generate_test_cases( kr: KR parameter of the GEMM micro-kernel. sr: SR parameter of the GEMM micro-kernel. mr_packed: Optional MR parameter for the left-hand packing function. - xw: boolean indicator for microkernel with extended weights. k_block: Number of K values processed per one iteration of the main loop of the micro-kernel. vector_tile: Indicates if vector tile for NR is specified in vectors rather @@ -993,7 +876,6 @@ def generate_test_cases( "KR": kr, "SR": sr, "MR_PACKED": mr_packed, - "EXTENDED_WEIGHTS": xw, "KBLOCK": k_block, "NR_SCALE": nr_scale, "ADJKBLOCK": 2 * k_block if is_pipelined else k_block, @@ -1011,7 +893,7 @@ def generate_test_cases( test_case = xngen.preprocess(GEMM_TEST_CODE, test_args) benchmark = xngen.preprocess( - GEMM_BENCH_CODE_XW if xw else GEMM_BENCH_CODE, + GEMM_BENCH_CODE, { "UKERNEL_NAME": ukernel_name, "GEMM": ukernel, @@ -1026,7 +908,6 @@ def generate_test_cases( "SR": sr, "MR_PACKED": mr_packed, "NR_SCALE": nr_scale, - "EXTENDED_WEIGHTS": xw, "ISA_CHECK": xnncommon.generate_isa_utilcheck_macro(isa), "CPP_CHECK": cpp_check, }, @@ -1127,7 +1008,6 @@ def main(args): kr, sr, mr_packed, - xw, vector_tile, requantization, arch, @@ -1142,7 +1022,6 @@ def main(args): kr, sr, mr_packed, - xw, k_block, vector_tile, init_fn,